8000 Refactor drain logic in streams.py to be reusable. · python/asyncio@7220ae3 · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Nov 23, 2017. It is now read-only.

Commit 7220ae3

Browse files
committed
Refactor drain logic in streams.py to be reusable.
1 parent 34acced commit 7220ae3

File tree

1 file changed

+61
-36
lines changed

1 file changed

+61
-36
lines changed

asyncio/streams.py

Lines changed: 61 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,63 @@ def factory():
9494
return (yield from loop.create_server(factory, host, port, **kwds))
9595

9696

97-
class StreamReaderProtocol(protocols.Protocol):
98-
"""Trivial helper class to adapt between Protocol and StreamReader.
97+
class FlowControlMixin(protocols.Protocol):
98+
"""Reusable flow control logic for StreamWriter.drain().
99+
100+
This implements the protocol methods pause_writing(),
101+
resume_reading() and connection_lost(). If the subclass overrides
102+
these it must call the super methods.
103+
104+
StreamWriter.drain() must check for error conditions and then call
105+
_make_drain_waiter(), which will return either () or a Future
106+
depending on the paused state.
107+
"""
108+
109+
def __init__(self, loop=None):
110+
self._loop = loop # May be None; we may never need it.
111+
self._paused = False
112+
self._drain_waiter = None
113+
114+
def pause_writing(self):
115+
assert not self._paused
116+
self._paused = True
117+
118+
def resume_writing(self):
119+
assert self._paused
120+
self._paused = False
121+
waiter = self._drain_waiter
122+
if waiter is not None:
123+
self._drain_waiter = None
124+
if not waiter.done():
125+
waiter.set_result(None)
126+
127+
def connection_lost(self, exc):
128+
# Wake up the writer if currently paused.
129+
if not self._paused:
130+
return
131+
waiter = self._drain_waiter
132+
if waiter is None:
133+
return
134+
self._drain_waiter = None
135+
if waiter.done():
136+
return
137+
if exc is None:
138+
waiter.set_result(None)
139+
else:
140+
waiter.set_exception(exc)
141+
142+
def _make_drain_waiter(self):
143+
if not self._paused:
144+
return ()
145+
waiter = self._drain_waiter
146+
assert waiter is None or waiter.cancelled()
147+
waiter = futures.Future(loop=self._loop)
148+
self._drain_waiter = waiter
149+
return waiter
150+
151+
152+
class StreamReaderProtocol(FlowControlMixin, protocols.Protocol):
153+
"""Helper class to adapt between Protocol and StreamReader.
99154
100155
(This is a helper class instead of making StreamReader itself a
101156
Protocol subclass, because the StreamReader has other potential
@@ -104,12 +159,10 @@ class StreamReaderProtocol(protocols.Protocol):
104159
"""
105160

106161
def __init__(self, stream_reader, client_connected_cb=None, loop=None):
162+
super().__init__(loop=loop)
107163
self._stream_reader = stream_reader
108164
self._stream_writer = None
109-
self._drain_waiter = None
110-
self._paused = False
111165
self._client_connected_cb = client_connected_cb
112-
self._loop = loop # May be None; we may never need it.
113166

114167
def connection_made(self, transport):
115168
self._stream_reader.set_transport(transport)
@@ -127,36 +180,14 @@ def connection_lost(self, exc):
127180
self._stream_reader.feed_eof()
128181
else:
129182
self._stream_reader.set_exception(exc)
130-
# Also wake up the writing side.
131-
if self._paused:
132-
waiter = self._drain_waiter
133-
if waiter is not None:
134-
self._drain_waiter = None
135-
if not waiter.done():
136-
if exc is None:
137-
waiter.set_result(None)
138-
else:
139-
waiter.set_exception(exc)
183+
super().connection_lost(exc)
140184

141185
def data_received(self, data):
142186
self._stream_reader.feed_data(data)
143187

144188
def eof_received(self):
145189
self._stream_reader.feed_eof()
146190

147-
def pause_writing(self):
148-
assert not self._paused
149-
self._paused = True
150-
151-
def resume_writing(self):
152-
assert self._paused
153-
self._paused = False
154-
waiter = self._drain_waiter
155-
if waiter is not None:
156-
self._drain_waiter = None
157-
if not waiter.done():
158-
waiter.set_result(None)
159-
160191

161192
class StreamWriter:
162193
"""Wraps a Transport.
@@ -211,17 +242,11 @@ def drain(self):
211242
completed, which will happen when the buffer is (partially)
212243
drained and the protocol is resumed.
213244
"""
214-
if self._reader._exception is not None:
245+
if self._reader is not None and self._reader._exception is not None:
215246
raise self._reader._exception
216247
if self._transport._conn_lost: # Uses private variable.
217248
raise ConnectionResetError('Connection lost')
218-
if not self._protocol._paused:
219-
return ()
220-
waiter = self._protocol._drain_waiter
221-
assert waiter is None or waiter.cancelled()
222-
waiter = futures.Future(loop=self._loop)
223-
self._protocol._drain_waiter = waiter
224-
return waiter
249+
return self._protocol._make_drain_waiter()
225250

226251

227252
class StreamReader:

0 commit comments

Comments
 (0)
0