8000 Raise RuntimeError when transport's FD is used with add_reader etc · python/asyncio@374930d · GitHub
[go: up one dir, main page]

Skip to content
This repository was archived by the owner on Nov 23, 2017. It is now read-only.

Commit 374930d

Browse files
committed
Raise RuntimeError when transport's FD is used with add_reader etc
PR #420
1 parent 496f60f commit 374930d

File tree

5 files changed

+170
-105
lines changed
  • tests
  • 5 files changed

    +170
    -105
    lines changed

    asyncio/selector_events.py

    Lines changed: 79 additions & 48 deletions
    Original file line numberDiff line numberDiff line change
    @@ -11,6 +11,7 @@
    1111
    import functools
    1212
    import socket
    1313
    import warnings
    14+
    import weakref
    1415
    try:
    1516
    import ssl
    1617
    except ImportError: # pragma: no cover
    @@ -64,6 +65,7 @@ def __init__(self, selector=None):
    6465
    logger.debug('Using selector: %s', selector.__class__.__name__)
    6566
    self._selector = selector
    6667
    self._make_self_pipe()
    68+
    self._transports = weakref.WeakValueDictionary()
    6769

    6870
    def _make_socket_transport(self, sock, protocol, waiter=None, *,
    6971
    extra=None, server=None):
    @@ -115,7 +117,7 @@ def _socketpair(self):
    115117
    raise NotImplementedError
    116118

    117119
    def _close_self_pipe(self):
    118-
    self.remove_reader(self._ssock.fileno())
    120+
    self._remove_reader(self._ssock.fileno())
    119121
    self._ssock.close()
    120122
    self._ssock = None
    121123
    self._csock.close()
    @@ -128,7 +130,7 @@ def _make_self_pipe(self):
    128130
    self._ssock.setblocking(False)
    129131
    self._csock.setblocking(False)
    130132
    self._internal_fds += 1
    131-
    self.add_reader(self._ssock.fileno(), self._read_from_self)
    133+
    self._add_reader(self._ssock.fileno(), self._read_from_self)
    132134

    133135
    def _process_self_data(self, data):
    134136
    pass
    @@ -163,8 +165,8 @@ def _write_to_self(self):
    163165

    164166
    def _start_serving(self, protocol_factory, sock,
    165167
    sslcontext=None, server=None, backlog=100):
    166-
    self.add_reader(sock.fileno(), self._accept_connection,
    167-
    protocol_factory, sock, sslcontext, server, backlog)
    168+
    self._add_reader(sock.fileno(), self._accept_connection,
    169+
    protocol_factory, sock, sslcontext, server, backlog)
    168170

    169171
    def _accept_connection(self, protocol_factory, sock,
    170172
    sslcontext=None, server=None, backlog=100):
    @@ -194,7 +196,7 @@ def _accept_connection(self, protocol_factory, sock,
    194196
    'exception': exc,
    195197
    'socket': sock,
    196198
    })
    197-
    self.remove_reader(sock.fileno())
    199+
    self._remove_reader(sock.fileno())
    198200
    self.call_later(constants.ACCEPT_RETRY_DELAY,
    199201
    self._start_serving,
    200202
    protocol_factory, sock, sslcontext, server,
    @@ -244,8 +246,18 @@ def _accept_connection2(self, protocol_factory, conn, extra,
    244246
    context['transport'] = transport
    245247
    self.call_exception_handler(context)
    246248

    247-
    def add_reader(self, fd, callback, *args):
    248-
    """Add a reader callback."""
    249+
    def _ensure_fd_no_transport(self, fd):
    250+
    try:
    251+
    transport = self._transports[fd]
    252+
    except KeyError:
    253+
    pass
    254+
    else:
    255+
    if not transport.is_closing():
    256+
    raise RuntimeError(
    257+
    'File descriptor {!r} is used by transport {!r}'.format(
    258+
    fd, transport))
    259+
    260+
    def _add_reader(self, fd, callback, *args):
    249261
    self._check_closed()
    250262
    handle = events.Handle(callback, args, self)
    251263
    try:
    @@ -260,8 +272,7 @@ def add_reader(self, fd, callback, *args):
    260272
    if reader is not None:
    261273
    reader.cancel()
    262274

    263-
    def remove_reader(self, fd):
    264-
    """Remove a reader callback."""
    275+
    def _remove_reader(self, fd):
    265276
    if self.is_closed():
    266277
    return False
    267278
    try:
    @@ -282,8 +293,7 @@ def remove_reader(self, fd):
    282293
    else:
    283294
    return False
    284295

    285-
    def add_writer(self, fd, callback, *args):
    286-
    """Add a writer callback.."""
    296+
    def _add_writer(self, fd, callback, *args):
    287297
    self._check_closed()
    288298
    handle = events.Handle(callback, args, self)
    289299
    try:
    @@ -298,7 +308,7 @@ def add_writer(self, fd, callback, *args):
    298308
    if writer is not None:
    299309
    writer.cancel()
    300310

    301-
    def remove_writer(self, fd):
    311+
    def _remove_writer(self, fd):
    302312
    """Remove a writer callback."""
    303313
    if self.is_closed():
    304314
    return False
    @@ -321,6 +331,26 @@ def remove_writer(self, fd):
    321331
    else:
    322332
    return False
    323333

    334+
    def add_reader(self, fd, callback, *args):
    335+
    """Add a reader callback."""
    336+
    self._ensure_fd_no_transport(fd)
    337+
    return self._add_reader(fd, callback, *args)
    338+
    339+
    def remove_reader(self, fd):
    340+
    """Remove a reader callback."""
    341+
    self._ensure_fd_no_transport(fd)
    342+
    return self._remove_reader(fd)
    343+
    344+
    def add_writer(self, fd, callback, *args):
    345+
    """Add a writer callback.."""
    346+
    self._ensure_fd_no_transport(fd)
    347+
    return self._add_writer(fd, callback, *args)
    348+
    349+
    def remove_writer(self, fd):
    350+
    """Remove a writer callback."""
    351+
    self._ensure_fd_no_transport(fd)
    352+
    return self._remove_writer(fd)
    353+
    324354
    def sock_recv(self, sock, n):
    325355
    """Receive data from the socket.
    326356
    @@ -494,17 +524,17 @@ def _process_events(self, event_list):
    494524
    fileobj, (reader, writer) = key.fileobj, key.data
    495525
    if mask & selectors.EVENT_READ and reader is not None:
    496526
    if reader._cancelled:
    497-
    self.remove_reader(fileobj)
    527+
    self._remove_reader(fileobj)
    498528
    else:
    499529
    self._add_callback(reader)
    500530
    if mask & selectors.EVENT_WRITE and writer is not None:
    501531
    if writer._cancelled:
    502-
    self.remove_writer(fileobj)
    532+
    self._remove_writer(fileobj)
    503533
    else:
    504534
    self._add_callback(writer)
    505535

    506536
    def _stop_serving(self, sock):
    507-
    self.remove_reader(sock.fileno())
    537+
    self._remove_reader(sock.fileno())
    508538
    sock.close()
    509539

    510540

    @@ -539,6 +569,7 @@ def __init__(self, loop, sock, protocol, extra=None, server=None):
    539569
    self._closing = False # Set when close() called.
    540570
    if self._server is not None:
    541571
    self._server._attach()
    572+
    loop._transports[self._sock_fd] = self
    542573

    543574
    def __repr__(self):
    544575
    info = [self.__class__.__name__]
    @@ -584,10 +615,10 @@ def close(self):
    584615
    if self._closing:
    585616
    return
    586617
    self._closing = True
    587-
    self._loop.remove_reader(self._sock_fd)
    618+
    self._loop._remove_reader(self._sock_fd)
    588619
    if not self._buffer:
    589620
    self._conn_lost += 1
    590-
    self._loop.remove_writer(self._sock_fd)
    621+
    self._loop._remove_writer(self._sock_fd)
    591622
    self._loop.call_soon(self._call_connection_lost, None)
    592623

    593624
    # On Python 3.3 and older, objects with a destructor part of a reference
    @@ -618,10 +649,10 @@ def _force_close(self, exc):
    618649
    return
    619650
    if self._buffer:
    620651
    self._buffer.clear()
    621-
    self._loop.remove_writer(self._sock_fd)
    652+
    self._loop._remove_writer(self._sock_fd)
    622653
    if not self._closing:
    623654
    self._closing = True
    624-
    self._loop.remove_reader(self._sock_fd)
    655+
    self._loop._remove_reader(self._sock_fd)
    625656
    self._conn_lost += 1
    626657
    self._loop.call_soon(self._call_connection_lost, exc)
    627658

    @@ -658,7 +689,7 @@ def __init__(self, loop, sock, protocol, waiter=None,
    658689

    659690
    self._loop.call_soon(self._protocol.connection_made, self)
    660691
    # only start reading when connection_made() has been called
    661-
    self._loop.call_soon(self._loop.add_reader,
    692+
    self._loop.call_soon(self._loop._add_reader,
    662693
    self._sock_fd, self._read_ready)
    663694
    if waiter is not None:
    664695
    # only wake up the waiter when connection_made() has been called
    @@ -671,7 +702,7 @@ def pause_reading(self):
    671702
    if self._paused:
    672703
    raise RuntimeError('Already paused')
    673704
    self._paused = True
    674-
    self._loop.remove_reader(self._sock_fd)
    705+
    self._loop._remove_reader(self._sock_fd)
    675706
    if self._loop.get_debug():
    676707
    logger.debug("%r pauses reading", self)
    677708

    @@ -681,7 +712,7 @@ def resume_reading(self):
    681712
    self._paused = False
    682713
    if self._closing:
    683714
    return
    684-
    self._loop.add_reader(self._sock_fd, self._read_ready)
    715+
    self._loop._add_reader(self._sock_fd, self._read_ready)
    685716
    if self._loop.get_debug():
    686717
    logger.debug("%r resumes reading", self)
    687718

    @@ -705,7 +736,7 @@ def _read_ready(self):
    705736
    # We're keeping the connection open so the
    706737
    # protocol can write more, but we still can't
    707738
    # receive more, so remove the reader callback.
    708-
    self._loop.remove_reader(self._sock_fd)
    739+
    self._loop._remove_reader(self._sock_fd)
    709740
    else:
    710741
    self.close()
    711742

    @@ -738,7 +769,7 @@ def write(self, data):
    738769
    if not data:
    739770
    return
    740771
    # Not all was written; register write handler.
    741-
    self._loop.add_writer(self._sock_fd, self._write_ready)
    772+
    self._loop._add_writer(self._sock_fd, self._write_ready)
    742773

    743774
    # Add it to the buffer.
    744775
    self._buffer.extend(data)
    @@ -754,15 +785,15 @@ def _write_ready(self):
    754785
    except (BlockingIOError, InterruptedError):
    755786
    pass
    756787
    except Exception as exc:
    757-
    self._loop.remove_writer(self._sock_fd)
    788+
    self._loop._remove_writer(self._sock_fd)
    758789
    self._buffer.clear()
    759790
    self._fatal_error(exc, 'Fatal write error on socket transport')
    760791
    else:
    761792
    if n:
    762793
    del self._buffer[:n]
    763794
    self._maybe_resume_protocol() # May append to buffer.
    764795
    if not self._buffer:
    765-
    self._loop.remove_writer(self._sock_fd)
    796+
    self._loop._remove_writer(self._sock_fd)
    766797
    if self._closing:
    767798
    self._call_connection_lost(None)
    768799
    elif self._eof:
    @@ -833,28 +864,28 @@ def _on_handshake(self, start_t 10000 ime):
    833864
    try:
    834865
    self._sock.do_handshake()
    835866
    except ssl.SSLWantReadError:
    836-
    self._loop.add_reader(self._sock_fd,
    837-
    self._on_handshake, start_time)
    867+
    self._loop._add_reader(self._sock_fd,
    868+
    self._on_handshake, start_time)
    838869
    return
    839870
    except ssl.SSLWantWriteError:
    840-
    self._loop.add_writer(self._sock_fd,
    841-
    self._on_handshake, start_time)
    871+
    self._loop._add_writer(self._sock_fd,
    872+
    self._on_handshake, start_time)
    842873
    return
    843874
    except BaseException as exc:
    844875
    if self._loop.get_debug():
    845876
    logger.warning("%r: SSL handshake failed",
    846877
    self, exc_info=True)
    847-
    self._loop.remove_reader(self._sock_fd)
    848-
    self._loop.remove_writer(self._sock_fd)
    878+
    self._loop._remove_reader(self._sock_fd)
    879+
    self._loop._remove_writer(self._sock_fd)
    849880
    self._sock.close()
    850881
    self._wakeup_waiter(exc)
    851882
    if isinstance(exc, Exception):
    852883
    return
    853884
    else:
    854885
    raise
    855886

    856-
    self._loop.remove_reader(self._sock_fd)
    857-
    self._loop.remove_writer(self._sock_fd)
    887+
    self._loop._remove_reader(self._sock_fd)
    888+
    self._loop._remove_writer(self._sock_fd)
    858889

    859890
    peercert = self._sock.getpeercert()
    860891
    if not hasattr(self._sslcontext, 'check_hostname'):
    @@ -882,7 +913,7 @@ def _on_handshake(self, start_time):
    882913

    883914
    self._read_wants_write = False
    884915
    self._write_wants_read = False
    885-
    self._loop.add_reader(self._sock_fd, self._read_ready)
    916+
    self._loop._add_reader(self._sock_fd, self._read_ready)
    886917
    self._protocol_connected = True
    887918
    self._loop.call_soon(self._protocol.connection_made, self)
    888919
    # only wake up the waiter when connection_made() has been called
    @@ -904,7 +935,7 @@ def pause_reading(self):
    904935
    if self._paused:
    905936
    raise RuntimeError('Already paused')
    906937
    self._paused = True
    907-
    self._loop.remove_reader(self._sock_fd)
    938+
    self._loop._remove_reader(self._sock_fd)
    908939
    if self._loop.get_debug():
    909940
    logger.debug("%r pauses reading", self)
    910941

    @@ -914,7 +945,7 @@ def resume_reading(self):
    914945
    self._paused = False
    915946
    if self._closing:
    916947
    return
    917-
    self._loop.add_reader(self._sock_fd, self._read_ready)
    948+
    self._loop._add_reader(self._sock_fd, self._read_ready)
    918949
    if self._loop.get_debug():
    919950
    logger.debug("%r resumes reading", self)
    920951

    @@ -926,16 +957,16 @@ def _read_ready(self):
    926957
    self._write_ready()
    927958

    928959
    if self._buffer:
    929-
    self._loop.add_writer(self._sock_fd, self._write_ready)
    960+
    self._loop._add_writer(self._sock_fd, self._write_ready)
    930961

    931962
    try:
    932963
    data = self._sock.recv(self.max_size)
    933964
    except (BlockingIOError, InterruptedError, ssl.SSLWantReadError):
    934965
    pass
    935966
    except ssl.SSLWantWriteError:
    936967
    self._read_wants_write = True
    937-
    self._loop.remove_reader(self._sock_fd)
    938-
    self._loop.add_writer(self._sock_fd, self._write_ready)
    968+
    self._loop._remove_reader(self._sock_fd)
    969+
    self._loop._add_writer(self._sock_fd, self._write_ready)
    939970
    except Exception as exc:
    940971
    self._fatal_error(exc, 'Fatal read error on SSL transport')
    941972
    else:
    @@ -960,7 +991,7 @@ def _write_ready(self):
    960991
    self._read_ready()
    961992

    962993
    if not (self._paused or self._closing):
    963-
    self._loop.add_reader(self._sock_fd, self._read_ready)
    994+
    self._loop._add_reader(self._sock_fd, self._read_ready)
    964995

    965996
    if self._buffer:
    966997
    try:
    @@ -969,10 +1000,10 @@ def _write_ready(self):
    9691000
    n = 0
    9701001
    except ssl.SSLWantReadError:
    9711002
    n = 0
    972-
    self._loop.remove_writer(self._sock_fd)
    1003+
    self._loop._remove_writer(self._sock_fd)
    9731004
    self._write_wants_read = True
    9741005
    except Exception as exc:
    975-
    self._loop.remove_writer(self._sock_fd)
    1006+
    self._loop._remove_writer(self._sock_fd)
    9761007
    self._buffer.clear()
    9771008
    self._fatal_error(exc, 'Fatal write error on SSL transport')
    9781009
    return
    @@ -983,7 +1014,7 @@ def _write_ready(self):
    9831014
    self._maybe_resume_protocol() # May append to buffer.
    9841015

    9851016
    if not self._buffer:
    986-
    self._loop.remove_writer(self._sock_fd)
    1017+
    self._loop._remove_writer(self._sock_fd)
    9871018
    if self._closing:
    9881019
    self._call_connection_lost(None)
    9891020

    @@ -1001,7 +1032,7 @@ def write(self, data):
    10011032
    return
    10021033

    10031034
    if not self._buffer:
    1004-
    self._loop.add_writer(self._sock_fd, self._write_ready)
    1035+
    self._loop._add_writer(self._sock_fd, self._write_ready)
    10051036

    10061037
    # Add it to the buffer.
    10071038
    self._buffer.extend(data)
    @@ -1021,7 +1052,7 @@ def __init__(self, loop, sock, protocol, address=None,
    10211052
    self._address = address
    10221053
    self._loop.call_soon(self._protocol.connection_made, self)
    10231054
    # only start reading when connection_made() has been called
    1024-
    self._loop.call_soon(self._loop.add_reader,
    1055+
    self._loop.call_soon(self._loop._add_reader,
    10251056
    self._sock_fd, self._read_ready)
    10261057
    if waiter is not None:
    10271058
    # only wake up the waiter when connection_made() has been called
    @@ -1071,7 +1102,7 @@ def sendto(self, data, addr=None):
    10711102
    self._sock.sendto(data, addr)
    10721103
    return
    10731104
    except (BlockingIOError, InterruptedError):
    1074-
    self._loop.add_writer(self._sock_fd, self._sendto_ready)
    1105+
    self._loop._add_writer(self._sock_fd, self._sendto_ready)
    10751106
    except OSError as exc:
    10761107
    self._protocol.error_received(exc)
    10771108
    return
    @@ -1105,6 +1136,6 @@ def _sendto_ready(self):
    11051136

    11061137
    self._maybe_resume_protocol() # May append to buffer.
    11071138
    if not self._buffer:
    1108-
    self._loop.remove_writer(self._sock_fd)
    1139+
    self._loop._remove_writer(self._sock_fd)
    11091140
    if self._closing:
    11101141
    self._call_connection_lost(None)

    0 commit comments

    Comments
     (0)
    0