1import asyncio 
    2from typing import Optional, cast 
    3 
    4from .client_exceptions import ClientConnectionResetError 
    5from .helpers import set_exception 
    6from .tcp_helpers import tcp_nodelay 
    7 
    8 
    9class BaseProtocol(asyncio.Protocol): 
    10    __slots__ = ( 
    11        "_loop", 
    12        "_paused", 
    13        "_drain_waiter", 
    14        "_connection_lost", 
    15        "_reading_paused", 
    16        "transport", 
    17    ) 
    18 
    19    def __init__(self, loop: asyncio.AbstractEventLoop) -> None: 
    20        self._loop: asyncio.AbstractEventLoop = loop 
    21        self._paused = False 
    22        self._drain_waiter: Optional[asyncio.Future[None]] = None 
    23        self._reading_paused = False 
    24 
    25        self.transport: Optional[asyncio.Transport] = None 
    26 
    27    @property 
    28    def connected(self) -> bool: 
    29        """Return True if the connection is open.""" 
    30        return self.transport is not None 
    31 
    32    @property 
    33    def writing_paused(self) -> bool: 
    34        return self._paused 
    35 
    36    def pause_writing(self) -> None: 
    37        assert not self._paused 
    38        self._paused = True 
    39 
    40    def resume_writing(self) -> None: 
    41        assert self._paused 
    42        self._paused = False 
    43 
    44        waiter = self._drain_waiter 
    45        if waiter is not None: 
    46            self._drain_waiter = None 
    47            if not waiter.done(): 
    48                waiter.set_result(None) 
    49 
    50    def pause_reading(self) -> None: 
    51        if not self._reading_paused and self.transport is not None: 
    52            try: 
    53                self.transport.pause_reading() 
    54            except (AttributeError, NotImplementedError, RuntimeError): 
    55                pass 
    56            self._reading_paused = True 
    57 
    58    def resume_reading(self) -> None: 
    59        if self._reading_paused and self.transport is not None: 
    60            try: 
    61                self.transport.resume_reading() 
    62            except (AttributeError, NotImplementedError, RuntimeError): 
    63                pass 
    64            self._reading_paused = False 
    65 
    66    def connection_made(self, transport: asyncio.BaseTransport) -> None: 
    67        tr = cast(asyncio.Transport, transport) 
    68        tcp_nodelay(tr, True) 
    69        self.transport = tr 
    70 
    71    def connection_lost(self, exc: Optional[BaseException]) -> None: 
    72        # Wake up the writer if currently paused. 
    73        self.transport = None 
    74        if not self._paused: 
    75            return 
    76        waiter = self._drain_waiter 
    77        if waiter is None: 
    78            return 
    79        self._drain_waiter = None 
    80        if waiter.done(): 
    81            return 
    82        if exc is None: 
    83            waiter.set_result(None) 
    84        else: 
    85            set_exception( 
    86                waiter, 
    87                ConnectionError("Connection lost"), 
    88                exc, 
    89            ) 
    90 
    91    async def _drain_helper(self) -> None: 
    92        if self.transport is None: 
    93            raise ClientConnectionResetError("Connection lost") 
    94        if not self._paused: 
    95            return 
    96        waiter = self._drain_waiter 
    97        if waiter is None: 
    98            waiter = self._loop.create_future() 
    99            self._drain_waiter = waiter 
    100        await asyncio.shield(waiter)