@@ -94,8 +94,63 @@ def factory():
94
94
return (yield from loop .create_server (factory , host , port , ** kwds ))
95
95
96
96
97
- class StreamReaderProtocol (protocols .Protocol ):
98
- """Trivial helper class to adapt between Protocol and StreamReader.
97
+ class FlowControlMixin (protocols .Protocol ):
98
+ """Reusable flow control logic for StreamWriter.drain().
99
+
100
+ This implements the protocol methods pause_writing(),
101
+ resume_reading() and connection_lost(). If the subclass overrides
102
+ these it must call the super methods.
103
+
104
+ StreamWriter.drain() must check for error conditions and then call
105
+ _make_drain_waiter(), which will return either () or a Future
106
+ depending on the paused state.
107
+ """
108
+
109
+ def __init__ (self , loop = None ):
110
+ self ._loop = loop # May be None; we may never need it.
111
+ self ._paused = False
112
+ self ._drain_waiter = None
113
+
114
+ def pause_writing (self ):
115
+ assert not self ._paused
116
+ self ._paused = True
117
+
118
+ def resume_writing (self ):
119
+ assert self ._paused
120
+ self ._paused = False
121
+ waiter = self ._drain_waiter
122
+ if waiter is not None :
123
+ self ._drain_waiter = None
124
+ if not waiter .done ():
125
+ waiter .set_result (None )
126
+
127
+ def connection_lost (self , exc ):
128
+ # Wake up the writer if currently paused.
129
+ if not self ._paused :
130
+ return
131
+ waiter = self ._drain_waiter
132
+ if waiter is None :
133
+ return
134
+ self ._drain_waiter = None
135
+ if waiter .done ():
136
+ return
137
+ if exc is None :
138
+ waiter .set_result (None )
139
+ else :
140
+ waiter .set_exception (exc )
141
+
142
+ def _make_drain_waiter (self ):
143
+ if not self ._paused :
144
+ return ()
145
+ waiter = self ._drain_waiter
146
+ assert waiter is None or waiter .cancelled ()
147
+ waiter = futures .Future (loop = self ._loop )
148
+ self ._drain_waiter = waiter
149
+ return waiter
150
+
151
+
152
+ class StreamReaderProtocol (FlowControlMixin , protocols .Protocol ):
153
+ """Helper class to adapt between Protocol and StreamReader.
99
154
100
155
(This is a helper class instead of making StreamReader itself a
101
156
Protocol subclass, because the StreamReader has other potential
@@ -104,12 +159,10 @@ class StreamReaderProtocol(protocols.Protocol):
104
159
"""
105
160
106
161
def __init__ (self , stream_reader , client_connected_cb = None , loop = None ):
162
+ super ().__init__ (loop = loop )
107
163
self ._stream_reader = stream_reader
108
164
self ._stream_writer = None
109
- self ._drain_waiter = None
110
- self ._paused = False
111
165
self ._client_connected_cb = client_connected_cb
112
- self ._loop = loop # May be None; we may never need it.
113
166
114
167
def connection_made (self , transport ):
115
168
self ._stream_reader .set_transport (transport )
@@ -127,36 +180,14 @@ def connection_lost(self, exc):
127
180
self ._stream_reader .feed_eof ()
128
181
else :
129
182
self ._stream_reader .set_exception (exc )
130
- # Also wake up the writing side.
131
- if self ._paused :
132
- waiter = self ._drain_waiter
133
- if waiter is not None :
134
- self ._drain_waiter = None
135
- if not waiter .done ():
136
- if exc is None :
137
- waiter .set_result (None )
138
- else :
139
- waiter .set_exception (exc )
183
+ super ().connection_lost (exc )
140
184
141
185
def data_received (self , data ):
142
186
self ._stream_reader .feed_data (data )
143
187
144
188
def eof_received (self ):
145
189
self ._stream_reader .feed_eof ()
146
190
147
- def pause_writing (self ):
148
- assert not self ._paused
149
- self ._paused = True
150
-
151
- def resume_writing (self ):
152
- assert self ._paused
153
- self ._paused = False
154
- waiter = self ._drain_waiter
155
- if waiter is not None :
156
- self ._drain_waiter = None
157
- if not waiter .done ():
158
- waiter .set_result (None )
159
-
160
191
161
192
class StreamWriter :
162
193
"""Wraps a Transport.
@@ -211,17 +242,11 @@ def drain(self):
211
242
completed, which will happen when the buffer is (partially)
212
243
drained and the protocol is resumed.
213
244
"""
214
- if self ._reader ._exception is not None :
245
+ if self ._reader is not None and self . _reader ._exception is not None :
215
246
raise self ._reader ._exception
216
247
if self ._transport ._conn_lost : # Uses private variable.
217
248
raise ConnectionResetError ('Connection lost' )
218
- if not self ._protocol ._paused :
219
- return ()
220
- waiter = self ._protocol ._drain_waiter
221
- assert waiter is None or waiter .cancelled ()
222
- waiter = futures .Future (loop = self ._loop )
223
- self ._protocol ._drain_waiter = waiter
224
- return waiter
249
+ return self ._protocol ._make_drain_waiter ()
225
250
226
251
227
252
class StreamReader :
0 commit comments