1import asyncio
2from typing import Optional, cast
3
4from .helpers import set_exception
5from .tcp_helpers import tcp_nodelay
6
7
class BaseProtocol(asyncio.Protocol):
    """Flow-control plumbing shared by asyncio protocol implementations.

    Keeps track of three pieces of state:

    * whether the transport has asked us to stop writing
      (``pause_writing`` / ``resume_writing``);
    * whether reading has been paused on the transport through this object
      (``pause_reading`` / ``resume_reading``);
    * a single future that every coroutine blocked in ``_drain_helper``
      waits on until writing may resume or the connection goes away.
    """

    __slots__ = (
        "_loop",
        "_paused",
        "_drain_waiter",
        "_connection_lost",
        "_reading_paused",
        "transport",
    )

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        # Event loop used to create the shared drain future.
        self._loop: asyncio.AbstractEventLoop = loop
        # Write-side flow control: True between pause_writing and resume_writing.
        self._paused = False
        # Future shared by every coroutine currently parked in _drain_helper.
        self._drain_waiter: Optional[asyncio.Future[None]] = None
        # Read-side flow control as last requested through this protocol.
        self._reading_paused = False
        # Attached in connection_made, detached (None) in connection_lost.
        # NOTE(review): the "_connection_lost" slot is never assigned here;
        # presumably kept for subclasses -- confirm before removing it.
        self.transport: Optional[asyncio.Transport] = None

    @property
    def connected(self) -> bool:
        """Report whether a transport is attached, i.e. the connection is open."""
        return self.transport is not None

    def pause_writing(self) -> None:
        """asyncio flow-control callback: the transport wants writes to stop."""
        # asyncio pairs pause/resume calls; a repeat pause is a logic error.
        assert not self._paused
        self._paused = True

    def resume_writing(self) -> None:
        """asyncio flow-control callback: writing may continue.

        Clears the paused flag and releases every coroutine that is waiting
        in ``_drain_helper``.
        """
        assert self._paused
        self._paused = False

        # Detach the shared future first so new drains create a fresh one.
        pending, self._drain_waiter = self._drain_waiter, None
        if pending is not None and not pending.done():
            pending.set_result(None)

    def pause_reading(self) -> None:
        """Ask the attached transport to stop delivering data (idempotent).

        Does nothing when reading is already paused or no transport is
        attached. Errors raised by transports that lack or reject read
        control are ignored; the paused flag is recorded either way.
        """
        transport = self.transport
        if self._reading_paused or transport is None:
            return
        try:
            transport.pause_reading()
        except (AttributeError, NotImplementedError, RuntimeError):
            # Best effort: not every transport supports pausing reads.
            pass
        self._reading_paused = True

    def resume_reading(self) -> None:
        """Undo ``pause_reading`` on the attached transport (idempotent).

        Does nothing unless reading is currently paused and a transport is
        attached. Transport errors are ignored in the same way as in
        ``pause_reading``.
        """
        transport = self.transport
        if not self._reading_paused or transport is None:
            return
        try:
            transport.resume_reading()
        except (AttributeError, NotImplementedError, RuntimeError):
            # Best effort: not every transport supports resuming reads.
            pass
        self._reading_paused = False

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Attach *transport* after requesting TCP_NODELAY on it.

        The ``tcp_nodelay`` helper runs before the transport is stored, so a
        failure there leaves the protocol unconnected.
        """
        stream = cast(asyncio.Transport, transport)
        tcp_nodelay(stream, True)
        self.transport = stream

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """Detach the transport and wake a writer blocked on drain.

        A pending drain future exists only while writing is paused. On a
        clean close (``exc is None``) it is resolved normally. Otherwise it is
        failed with ``ConnectionError("Connection lost")`` through
        ``set_exception``, which also receives *exc*.
        """
        self.transport = None
        if self._paused and self._drain_waiter is not None:
            waiter, self._drain_waiter = self._drain_waiter, None
            if not waiter.done():
                if exc is None:
                    waiter.set_result(None)
                else:
                    set_exception(
                        waiter,
                        ConnectionError("Connection lost"),
                        exc,
                    )

    async def _drain_helper(self) -> None:
        """Wait until the transport allows writing again.

        Raises ``ConnectionResetError`` when no connection is open. Returns
        at once when writing is not paused. Otherwise all callers share one
        future, which each awaits through ``asyncio.shield``. Cancelling one
        caller therefore does not cancel the shared future.
        """
        if not self.connected:
            raise ConnectionResetError("Connection lost")
        if self._paused:
            waiter = self._drain_waiter
            if waiter is None:
                waiter = self._drain_waiter = self._loop.create_future()
            await asyncio.shield(waiter)