@@ -302,6 +302,7 @@ def test_start_tls_client_buf_proto_1(self):
302
302
303
303
server_context = test_utils .simple_server_sslcontext ()
304
304
client_context = test_utils .simple_client_sslcontext ()
305
+ client_con_made_calls = 0
305
306
306
307
def serve (sock ):
307
308
sock .settimeout (self .TIMEOUT )
@@ -315,20 +316,21 @@ def serve(sock):
315
316
data = sock .recv_all (len (HELLO_MSG ))
316
317
self .assertEqual (len (data ), len (HELLO_MSG ))
317
318
319
+ sock .sendall (b'2' )
320
+ data = sock .recv_all (len (HELLO_MSG ))
321
+ self .assertEqual (len (data ), len (HELLO_MSG ))
322
+
318
323
sock .shutdown (socket .SHUT_RDWR )
319
324
sock .close ()
320
325
321
- class ClientProto (asyncio .BufferedProtocol ):
322
- def __init__ (self , on_data , on_eof ):
326
+ class ClientProtoFirst (asyncio .BufferedProtocol ):
327
+ def __init__ (self , on_data ):
323
328
self .on_data = on_data
324
- self .on_eof = on_eof
325
- self .con_made_cnt = 0
326
329
self .buf = bytearray (1 )
327
330
328
- def connection_made (proto , tr ):
329
- proto .con_made_cnt += 1
330
- # Ensure connection_made gets called only once.
331
- self .assertEqual (proto .con_made_cnt , 1 )
331
+ def connection_made (self , tr ):
332
+ nonlocal client_con_made_calls
333
+ client_con_made_calls += 1
332
334
333
335
def get_buffer (self , sizehint ):
334
336
return self .buf
@@ -337,27 +339,50 @@ def buffer_updated(self, nsize):
337
339
assert nsize == 1
338
340
self .on_data .set_result (bytes (self .buf [:nsize ]))
339
341
342
+ class ClientProtoSecond (asyncio .Protocol ):
343
+ def __init__ (self , on_data , on_eof ):
344
+ self .on_data = on_data
345
+ self .on_eof = on_eof
346
+ self .con_made_cnt = 0
347
+
348
+ def connection_made (self , tr ):
349
+ nonlocal client_con_made_calls
350
+ client_con_made_calls += 1
351
+
352
+ def data_received (self , data ):
353
+ self .on_data .set_result (data )
354
+
340
355
def eof_received (self ):
341
356
self .on_eof .set_result (True )
342
357
343
358
async def client (addr ):
344
359
await asyncio .sleep (0.5 , loop = self .loop )
345
360
346
- on_data = self .loop .create_future ()
361
+ on_data1 = self .loop .create_future ()
362
+ on_data2 = self .loop .create_future ()
347
363
on_eof = self .loop .create_future ()
348
364
349
365
tr , proto = await self .loop .create_connection (
350
- lambda : ClientProto ( on_data , on_eof ), * addr )
366
+ lambda : ClientProtoFirst ( on_data1 ), * addr )
351
367
352
368
tr .write (HELLO_MSG )
353
369
new_tr = await self .loop .start_tls (tr , proto , client_context )
354
370
355
- self .assertEqual (await on_data , b'O' )
371
+ self .assertEqual (await on_data1 , b'O' )
372
+ new_tr .write (HELLO_MSG )
373
+
374
+ new_tr .set_protocol (ClientProtoSecond (on_data2 , on_eof ))
375
+ self .assertEqual (await on_data2 , b'2' )
356
376
new_tr .write (HELLO_MSG )
357
377
await on_eof
358
378
359
379
new_tr .close ()
360
380
381
+ # connection_made() should be called only once -- when
382
+ # we establish connection for the first time. Start TLS
383
+ # doesn't call connection_made() on application protocols.
384
+ self .assertEqual (client_con_made_calls , 1 )
385
+
361
386
with self .tcp_server (serve , timeout = self .TIMEOUT ) as srv :
362
387
self .loop .run_until_complete (
363
388
asyncio .wait_for (client (srv .addr ),
0 commit comments