Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/aiohttp/_websocket/writer.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

93 statements  

1"""WebSocket protocol versions 13 and 8.""" 

2 

3import asyncio 

4import random 

5import sys 

6from functools import partial 

7from typing import Final, Optional, Set, Union 

8 

9from ..base_protocol import BaseProtocol 

10from ..client_exceptions import ClientConnectionResetError 

11from ..compression_utils import ZLibBackend, ZLibCompressor 

12from .helpers import ( 

13 MASK_LEN, 

14 MSG_SIZE, 

15 PACK_CLOSE_CODE, 

16 PACK_LEN1, 

17 PACK_LEN2, 

18 PACK_LEN3, 

19 PACK_RANDBITS, 

20 websocket_mask, 

21) 

22from .models import WS_DEFLATE_TRAILING, WSMsgType 

23 

# Default number of bytes WebSocketWriter writes between flow-control checks
# (the ``limit`` argument of its constructor).
DEFAULT_LIMIT: Final[int] = 64 * 1024

# Opcodes 0x0-0x7 are data frames; 0x8-0xF are control frames (ping, pong,
# close). Control frames are never compressed.
WS_CONTROL_FRAME_OPCODE: Final[int] = 0x8

# Compressed payloads up to this many bytes are deflated synchronously on the
# event loop; larger ones are compressed in an executor. WebSocket peers
# generally expect low latency, and both executor calls and task creation add
# noticeable latency when chunks are small. A comparatively large 16 KiB
# threshold balances that overhead against blocking the event loop for too
# long with synchronous compression.
WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 16384

38 

39 

class WebSocketWriter:
    """Low-level writer for outgoing WebSocket frames.

    A writer is bound to one protocol/transport pair. It encodes payloads as
    RFC 6455 frames (optionally deflate-compressed per RFC 7692 and/or
    masked), writes them to the transport, and applies write flow control.
    It is used on either side of a connection and should stay free of
    application-level logic.
    """

    def __init__(
        self,
        protocol: BaseProtocol,
        transport: asyncio.Transport,
        *,
        use_mask: bool = False,
        limit: int = DEFAULT_LIMIT,
        random: random.Random = random.Random(),
        compress: int = 0,
        notakeover: bool = False,
    ) -> None:
        """Initialize a WebSocket writer.

        :param protocol: protocol consulted for flow control (``_paused`` and
            ``_drain_helper``).
        :param transport: transport the encoded frames are written to.
        :param use_mask: mask the payload of every outgoing frame; used when
            acting as a client.
        :param limit: number of bytes written between flow-control checks.
        :param random: source of the 32-bit masking keys. The default
            instance is created once, when the class body runs, and is shared
            by every writer that does not pass its own.
            NOTE(review): RFC 6455 section 5.3 asks for masking keys derived
            from a strong entropy source; ``random.Random`` is not
            cryptographically secure. Confirm this trade-off is intended.
        :param compress: window bits of the connection-wide compressor;
            ``0`` disables compression unless a frame requests it.
        :param notakeover: end each compressed frame with ``Z_FULL_FLUSH``
            instead of ``Z_SYNC_FLUSH``.
        """
        self.protocol = protocol
        self.transport = transport
        self.use_mask = use_mask
        self.compress = compress
        self.notakeover = notakeover
        # Pre-bind the 32-bit draw so each masked frame needs a single call.
        self.get_random_bits = partial(random.getrandbits, 32)
        self._limit = limit
        # Bytes written since the last flow-control check.
        self._output_size = 0
        self._closing = False
        # Connection-wide compressor, created lazily by _get_compressor().
        self._compressobj: Optional[ZLibCompressor] = None
        # Serializes compressed sends so the shared compressor never works on
        # two frames at once.
        self._send_lock = asyncio.Lock()
        # Strong references to in-flight shielded send tasks.
        self._background_tasks: Set[asyncio.Task[None]] = set()

    async def send_frame(
        self, message: bytes, opcode: int, compress: Optional[int] = None
    ) -> None:
        """Send ``message`` as the payload of a single frame.

        Frames take one of three routes:

        * Control frames, and data frames when no compression applies, are
          written immediately without taking the send lock.
        * Compressed payloads of at most ``WEBSOCKET_MAX_SYNC_CHUNK_SIZE``
          bytes are deflated synchronously while holding the send lock.
        * Larger compressed payloads are deflated asynchronously in a
          separate task that is shielded from cancellation.

        Afterwards, once more than ``limit`` bytes have been written since
        the last check, this waits for the transport to drain if the protocol
        is paused.

        :param message: frame payload.
        :param opcode: WebSocket opcode of the frame.
        :param compress: per-frame compression window bits. When truthy, a
            one-off compressor is used instead of the shared one.
        :raises ClientConnectionResetError: if the writer is closing and the
            opcode shares no bit with ``WSMsgType.CLOSE`` (presumably any
            frame other than close, ping or pong), or if the transport is
            closing.
        """
        if self._closing and not (opcode & WSMsgType.CLOSE):
            raise ClientConnectionResetError("Cannot write to closing transport")

        is_control_frame = opcode >= WS_CONTROL_FRAME_OPCODE
        compression_enabled = bool(compress or self.compress)

        if is_control_frame or not compression_enabled:
            # Control frames are never compressed, and uncompressed frames
            # need neither the lock nor the shield.
            self._write_websocket_frame(message, opcode, 0)
        elif len(message) > WEBSOCKET_MAX_SYNC_CHUNK_SIZE:
            # Large payload: compression runs in the executor, so there are
            # await points. Lock, compress and write must complete as one
            # unit. If a cancellation landed between compressing and writing,
            # the compressor state would advance without the bytes reaching
            # the peer, corrupting later frames. Running the unit in its own
            # task and shielding that task lets it finish even if this
            # coroutine is cancelled.
            loop = asyncio.get_running_loop()
            pending = self._send_compressed_frame_async_locked(
                message, opcode, compress
            )
            if sys.version_info >= (3, 12):
                # Start eagerly to skip one trip through the scheduler.
                task = asyncio.Task(pending, loop=loop, eager_start=True)
            else:
                task = loop.create_task(pending)
            # The event loop keeps only weak references to tasks, so hold a
            # strong one until the task is done.
            in_flight = self._background_tasks
            in_flight.add(task)
            task.add_done_callback(in_flight.discard)
            await asyncio.shield(task)
        else:
            # Small payload: synchronous compression has no await points, but
            # the lock is still required. Without it, this frame could
            # interleave with a large frame that is still being compressed in
            # the executor and corrupt the compressor state.
            async with self._send_lock:
                self._send_compressed_frame_sync(message, opcode, compress)

        # Every byte of this frame has been written or buffered by now, so
        # yielding to the event loop from here on is safe.
        if self._output_size <= self._limit:
            return
        self._output_size = 0
        if self.protocol._paused:
            await self.protocol._drain_helper()

    def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None:
        """Encode ``message`` as one final frame and write it to the transport.

        This builds the header (FIN, RSV bits, opcode, payload length and
        mask bit) and masks the payload when ``use_mask`` is set. It then
        writes the frame and adds the written byte count to
        ``_output_size``. Compression and flow control are the caller's job.

        :param message: already-encoded (and possibly compressed) payload.
        :param opcode: WebSocket opcode.
        :param rsv: RSV bits to OR into the first header byte.
        :raises ClientConnectionResetError: if the transport is closing.
        """
        payload_len = len(message)
        masked = self.use_mask
        mask_flag = 0x80 if masked else 0

        # First byte: FIN (0x80) is always set, plus the RSV bits and opcode.
        # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
        head = 0x80 | rsv | opcode
        if payload_len >= 65536:
            # Marker 127: a 64-bit extended length follows (2 + 8 bytes).
            header = PACK_LEN3(head, 127 | mask_flag, payload_len)
            header_len = 10
        elif payload_len >= 126:
            # Marker 126: a 16-bit extended length follows (2 + 2 bytes).
            header = PACK_LEN2(head, 126 | mask_flag, payload_len)
            header_len = 4
        else:
            # The length fits directly into the second header byte.
            header = PACK_LEN1(head, payload_len | mask_flag)
            header_len = 2

        transport = self.transport
        if transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")

        if masked:
            # https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
            # A fresh random 32-bit masking key is sent after the header and
            # XOR-ed into a copy of the payload.
            mask = PACK_RANDBITS(self.get_random_bits())
            masked_payload = bytearray(message)
            websocket_mask(mask, masked_payload)
            transport.write(header + mask + masked_payload)
            self._output_size += MASK_LEN
        elif payload_len > MSG_SIZE:
            # Large payloads are written separately, presumably to avoid
            # copying them into a concatenated buffer.
            transport.write(header)
            transport.write(message)
        else:
            transport.write(header + message)

        self._output_size += header_len + payload_len

    @staticmethod
    def _make_compressor(window_bits: int) -> ZLibCompressor:
        """Create a speed-oriented deflate compressor for ``window_bits``."""
        # Negative wbits selects a raw deflate stream (zlib convention);
        # presumably ZLibCompressor forwards it unchanged.
        return ZLibCompressor(
            level=ZLibBackend.Z_BEST_SPEED,
            wbits=-window_bits,
            max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
        )

    def _get_compressor(self, compress: Optional[int]) -> ZLibCompressor:
        """Return the compressor for the next frame.

        A truthy ``compress`` yields a fresh, uncached compressor with those
        window bits, which leaves the connection-wide compression context
        untouched. Otherwise the shared compressor for ``self.compress`` is
        returned, and it is created on first use.
        """
        if compress:
            return self._make_compressor(compress)
        shared = self._compressobj
        if not shared:
            shared = self._compressobj = self._make_compressor(self.compress)
        return shared

    def _flush_mode(self) -> int:
        """Return the zlib flush mode used to terminate a compressed frame."""
        # Without context takeover the stream is fully flushed, so
        # decompression of the next frame can restart from a clean state.
        return (
            ZLibBackend.Z_FULL_FLUSH if self.notakeover else ZLibBackend.Z_SYNC_FLUSH
        )

    def _send_compressed_frame_sync(
        self, message: bytes, opcode: int, compress: Optional[int]
    ) -> None:
        """Deflate a small payload on the event loop and write it as one frame.

        There is no await point between compressing and writing, so
        cancellation cannot interrupt the operation halfway through.
        """
        compressor = self._get_compressor(compress)
        deflated = compressor.compress_sync(message)
        flushed = compressor.flush(self._flush_mode())
        # Strip the WS_DEFLATE_TRAILING suffix left by the flush, and set RSV1
        # (0x40) to mark the frame as compressed.
        # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
        self._write_websocket_frame(
            (deflated + flushed).removesuffix(WS_DEFLATE_TRAILING), opcode, 0x40
        )

    async def _send_compressed_frame_async_locked(
        self, message: bytes, opcode: int, compress: Optional[int]
    ) -> None:
        """Deflate a large payload in the executor and write it, holding the lock.

        The send lock is held for the whole operation, so concurrent sends
        cannot corrupt the compressor state. This MUST run shielded from
        cancellation: if it were cancelled after compressing but before
        writing, the compressor state would be advanced without the data
        having been sent, corrupting subsequent frames.
        """
        async with self._send_lock:
            compressor = self._get_compressor(compress)
            deflated = await compressor.compress(message)
            flushed = compressor.flush(self._flush_mode())
            # Strip the WS_DEFLATE_TRAILING suffix left by the flush, and set
            # RSV1 (0x40) to mark the frame as compressed.
            # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
            self._write_websocket_frame(
                (deflated + flushed).removesuffix(WS_DEFLATE_TRAILING), opcode, 0x40
            )

    async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
        """Send a CLOSE frame carrying ``code`` and ``message``.

        A ``str`` message is encoded as UTF-8 first. The writer is marked as
        closing even if sending the frame raises.
        """
        reason = message.encode("utf-8") if isinstance(message, str) else message
        try:
            await self.send_frame(PACK_CLOSE_CODE(code) + reason, opcode=WSMsgType.CLOSE)
        finally:
            self._closing = True
125 """ 

126 Write a websocket frame to the transport. 

127 

128 This method handles frame header construction, masking, and writing to transport. 

129 It does not handle compression or flow control - those are the responsibility 

130 of the caller. 

131 """ 

132 msg_length = len(message) 

133 

134 use_mask = self.use_mask 

135 mask_bit = 0x80 if use_mask else 0 

136 

137 # Depending on the message length, the header is assembled differently. 

138 # The first byte is reserved for the opcode and the RSV bits. 

139 first_byte = 0x80 | rsv | opcode 

140 if msg_length < 126: 

141 header = PACK_LEN1(first_byte, msg_length | mask_bit) 

142 header_len = 2 

143 elif msg_length < 65536: 

144 header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length) 

145 header_len = 4 

146 else: 

147 header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length) 

148 header_len = 10 

149 

150 if self.transport.is_closing(): 

151 raise ClientConnectionResetError("Cannot write to closing transport") 

152 

153 # https://datatracker.ietf.org/doc/html/rfc6455#section-5.3 

154 # If we are using a mask, we need to generate it randomly 

155 # and apply it to the message before sending it. A mask is 

156 # a 32-bit value that is applied to the message using a 

157 # bitwise XOR operation. It is used to prevent certain types 

158 # of attacks on the websocket protocol. The mask is only used 

159 # when aiohttp is acting as a client. Servers do not use a mask. 

160 if use_mask: 

161 mask = PACK_RANDBITS(self.get_random_bits()) 

162 message = bytearray(message) 

163 websocket_mask(mask, message) 

164 self.transport.write(header + mask + message) 

165 self._output_size += MASK_LEN 

166 elif msg_length > MSG_SIZE: 

167 self.transport.write(header) 

168 self.transport.write(message) 

169 else: 

170 self.transport.write(header + message) 

171 

172 self._output_size += header_len + msg_length 

173 

174 def _get_compressor(self, compress: Optional[int]) -> ZLibCompressor: 

175 """Get or create a compressor object for the given compression level.""" 

176 if compress: 

177 # Do not set self._compress if compressing is for this frame 

178 return ZLibCompressor( 

179 level=ZLibBackend.Z_BEST_SPEED, 

180 wbits=-compress, 

181 max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, 

182 ) 

183 if not self._compressobj: 

184 self._compressobj = ZLibCompressor( 

185 level=ZLibBackend.Z_BEST_SPEED, 

186 wbits=-self.compress, 

187 max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, 

188 ) 

189 return self._compressobj 

190 

191 def _send_compressed_frame_sync( 

192 self, message: bytes, opcode: int, compress: Optional[int] 

193 ) -> None: 

194 """ 

195 Synchronous send for small compressed frames. 

196 

197 This is used for small compressed payloads that compress synchronously in the event loop. 

198 Since there are no await points, this is inherently cancellation-safe. 

199 """ 

200 # RSV are the reserved bits in the frame header. They are used to 

201 # indicate that the frame is using an extension. 

202 # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 

203 compressobj = self._get_compressor(compress) 

204 # (0x40) RSV1 is set for compressed frames 

205 # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 

206 self._write_websocket_frame( 

207 ( 

208 compressobj.compress_sync(message) 

209 + compressobj.flush( 

210 ZLibBackend.Z_FULL_FLUSH 

211 if self.notakeover 

212 else ZLibBackend.Z_SYNC_FLUSH 

213 ) 

214 ).removesuffix(WS_DEFLATE_TRAILING), 

215 opcode, 

216 0x40, 

217 ) 

218 

219 async def _send_compressed_frame_async_locked( 

220 self, message: bytes, opcode: int, compress: Optional[int] 

221 ) -> None: 

222 """ 

223 Async send for large compressed frames with lock. 

224 

225 Acquires the lock and compresses large payloads asynchronously in 

226 the executor. The lock is held for the entire operation to ensure 

227 the compressor state is not corrupted by concurrent sends. 

228 

229 MUST be run shielded from cancellation. If cancelled after 

230 compression but before sending, the compressor state would be 

231 advanced but data not sent, corrupting subsequent frames. 

232 """ 

233 async with self._send_lock: 

234 # RSV are the reserved bits in the frame header. They are used to 

235 # indicate that the frame is using an extension. 

236 # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 

237 compressobj = self._get_compressor(compress) 

238 # (0x40) RSV1 is set for compressed frames 

239 # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 

240 self._write_websocket_frame( 

241 ( 

242 await compressobj.compress(message) 

243 + compressobj.flush( 

244 ZLibBackend.Z_FULL_FLUSH 

245 if self.notakeover 

246 else ZLibBackend.Z_SYNC_FLUSH 

247 ) 

248 ).removesuffix(WS_DEFLATE_TRAILING), 

249 opcode, 

250 0x40, 

251 ) 

252 

253 async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None: 

254 """Close the websocket, sending the specified code and message.""" 

255 if isinstance(message, str): 

256 message = message.encode("utf-8") 

257 try: 

258 await self.send_frame( 

259 PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE 

260 ) 

261 finally: 

262 self._closing = True