1import asyncio
2from typing import Optional, cast
3
4from .client_exceptions import ClientConnectionResetError
5from .helpers import set_exception
6from .tcp_helpers import tcp_nodelay
7
8
class BaseProtocol(asyncio.Protocol):
    """Common asyncio protocol base with write flow control and read pausing.

    Writing: the transport's ``pause_writing``/``resume_writing`` callbacks
    toggle a flag. Coroutines calling ``_drain_helper`` wait on one shared
    future until writing resumes or the connection is lost.

    Reading: ``pause_reading``/``resume_reading`` forward to the attached
    transport (if any). They remember the current state, so repeated calls
    in the same direction are no-ops.
    """

    __slots__ = (
        "_loop",
        "_paused",
        "_drain_waiter",
        "_connection_lost",
        "_reading_paused",
        "transport",
    )
    # NOTE(review): "_connection_lost" is never assigned by this class;
    # presumably retained for subclasses / backward compatibility — verify
    # before removing it.

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        """Bind the protocol to *loop*; nothing is paused, no transport yet."""
        self._loop: asyncio.AbstractEventLoop = loop
        # Write-side flow-control state.
        self._paused = False
        self._drain_waiter: Optional[asyncio.Future[None]] = None
        # Read-side state, toggled by pause_reading()/resume_reading().
        self._reading_paused = False

        self.transport: Optional[asyncio.Transport] = None

    @property
    def connected(self) -> bool:
        """Return True if the connection is open (a transport is attached)."""
        return self.transport is not None

    @property
    def writing_paused(self) -> bool:
        """Return True between pause_writing() and the next resume_writing()."""
        return self._paused

    def pause_writing(self) -> None:
        """Flow-control callback: hold writers until resume_writing()."""
        assert not self._paused
        self._paused = True

    def resume_writing(self) -> None:
        """Flow-control callback: clear the pause and wake a pending drain."""
        assert self._paused
        self._paused = False

        # Detach the shared waiter first so the next drain creates a new one.
        pending, self._drain_waiter = self._drain_waiter, None
        if pending is not None and not pending.done():
            pending.set_result(None)

    def pause_reading(self) -> None:
        """Ask the transport to stop reading, once, until resume_reading().

        Does nothing if reading is already paused or no transport is attached.
        Transports that lack the method or refuse the call (AttributeError,
        NotImplementedError, RuntimeError) are tolerated. In that case the
        paused state is still recorded.
        """
        if self._reading_paused or self.transport is None:
            return
        try:
            self.transport.pause_reading()
        except (AttributeError, NotImplementedError, RuntimeError):
            # Best effort: not every transport supports read pausing.
            pass
        self._reading_paused = True

    def resume_reading(self) -> None:
        """Ask the transport to start reading again after pause_reading().

        Does nothing unless reading is paused and a transport is attached.
        The same transport errors as in pause_reading() are tolerated, and
        the paused state is still cleared.
        """
        if not self._reading_paused or self.transport is None:
            return
        try:
            self.transport.resume_reading()
        except (AttributeError, NotImplementedError, RuntimeError):
            # Best effort: not every transport supports read pausing.
            pass
        self._reading_paused = False

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Attach *transport* after calling ``tcp_nodelay`` on it."""
        stream_transport = cast(asyncio.Transport, transport)
        # NOTE(review): tcp_nodelay is an external helper; presumably it sets
        # TCP_NODELAY on the underlying socket — confirm in .tcp_helpers.
        tcp_nodelay(stream_transport, True)
        self.transport = stream_transport

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """Detach the transport and release a writer blocked in _drain_helper().

        The waiter is only touched while writing is paused. On a clean close
        (*exc* is None) it is resolved with None. Otherwise it is failed with
        ``ConnectionError("Connection lost")`` through ``set_exception``,
        which also receives *exc* (presumably recorded as the cause).
        """
        self.transport = None
        if not self._paused or self._drain_waiter is None:
            return
        waiter, self._drain_waiter = self._drain_waiter, None
        if waiter.done():
            return
        if exc is not None:
            set_exception(
                waiter,
                ConnectionError("Connection lost"),
                exc,
            )
        else:
            waiter.set_result(None)

    async def _drain_helper(self) -> None:
        """Wait until writing may proceed.

        Raises ClientConnectionResetError if no transport is attached.
        Returns at once when writing is not paused. Otherwise it waits until
        resume_writing() or connection_lost() settles the waiter, and
        re-raises the waiter's exception if it was failed.

        Concurrent callers share one waiter future. It is awaited through
        ``asyncio.shield`` so that cancelling one caller does not cancel the
        shared future seen by the others.
        """
        if self.transport is None:
            raise ClientConnectionResetError("Connection lost")
        if self._paused:
            waiter = self._drain_waiter
            if waiter is None:
                waiter = self._drain_waiter = self._loop.create_future()
            await asyncio.shield(waiter)