Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/_websocket/writer.py: 24%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

93 statements  

1"""WebSocket protocol versions 13 and 8.""" 

2 

3import asyncio 

4import random 

5import sys 

6from functools import partial 

7from typing import Final 

8 

9from ..base_protocol import BaseProtocol 

10from ..client_exceptions import ClientConnectionResetError 

11from ..compression_utils import ZLibBackend, ZLibCompressor 

12from ..helpers import DEFAULT_CHUNK_SIZE 

13from .helpers import ( 

14 MASK_LEN, 

15 MSG_SIZE, 

16 PACK_CLOSE_CODE, 

17 PACK_LEN1, 

18 PACK_LEN2, 

19 PACK_LEN3, 

20 PACK_RANDBITS, 

21 websocket_mask, 

22) 

23from .models import WS_DEFLATE_TRAILING, WSMsgType 

24 

25# WebSocket opcode boundary: opcodes 0-7 are data frames, 8-15 are control frames 

26# Control frames (ping, pong, close) are never compressed 

27WS_CONTROL_FRAME_OPCODE: Final[int] = 8 

28 

29# For websockets, keeping latency low is extremely important as implementations 

30# generally expect to be able to send and receive messages quickly. We use a 

31# larger chunk size to reduce the number of executor calls and avoid task 

32# creation overhead, since both are significant sources of latency when chunks 

33# are small. A size of 16KiB was chosen as a balance between avoiding task 

34# overhead and not blocking the event loop too long with synchronous compression. 

35 

36WEBSOCKET_MAX_SYNC_CHUNK_SIZE = 16 * 1024 

37 

38 

39class WebSocketWriter: 

40 """WebSocket writer. 

41 

42 The writer is responsible for sending messages to the client. It is 

43 created by the protocol when a connection is established. The writer 

44 should avoid implementing any application logic and should only be 

45 concerned with the low-level details of the WebSocket protocol. 

46 """ 

47 

48 def __init__( 

49 self, 

50 protocol: BaseProtocol, 

51 transport: asyncio.Transport, 

52 *, 

53 use_mask: bool = False, 

54 limit: int = DEFAULT_CHUNK_SIZE, 

55 random: random.Random = random.Random(), 

56 compress: int = 0, 

57 notakeover: bool = False, 

58 ) -> None: 

59 """Initialize a WebSocket writer.""" 

60 self.protocol = protocol 

61 self.transport = transport 

62 self.use_mask = use_mask 

63 self.get_random_bits = partial(random.getrandbits, 32) 

64 self.compress = compress 

65 self.notakeover = notakeover 

66 self._closing = False 

67 self._limit = limit 

68 self._output_size = 0 

69 self._compressobj: ZLibCompressor | None = None 

70 self._send_lock = asyncio.Lock() 

71 self._background_tasks: set[asyncio.Task[None]] = set() 

72 

73 async def send_frame( 

74 self, message: bytes, opcode: int, compress: int | None = None 

75 ) -> None: 

76 """Send a frame over the websocket with message as its payload.""" 

77 if self._closing and not (opcode & WSMsgType.CLOSE): 

78 raise ClientConnectionResetError("Cannot write to closing transport") 

79 

80 if not (compress or self.compress) or opcode >= WS_CONTROL_FRAME_OPCODE: 

81 # Non-compressed frames don't need lock or shield 

82 self._write_websocket_frame(message, opcode, 0) 

83 elif len(message) <= WEBSOCKET_MAX_SYNC_CHUNK_SIZE: 

84 # Small compressed payloads - compress synchronously in event loop 

85 # We need the lock even though sync compression has no await points. 

86 # This prevents small frames from interleaving with large frames that 

87 # compress in the executor, avoiding compressor state corruption. 

88 async with self._send_lock: 

89 self._send_compressed_frame_sync(message, opcode, compress) 

90 else: 

91 # Large compressed frames need shield to prevent corruption 

92 # For large compressed frames, the entire compress+send 

93 # operation must be atomic. If cancelled after compression but 

94 # before send, the compressor state would be advanced but data 

95 # not sent, corrupting subsequent frames. 

96 # Create a task to shield from cancellation 

97 # The lock is acquired inside the shielded task so the entire 

98 # operation (lock + compress + send) completes atomically. 

99 # Use eager_start on Python 3.12+ to avoid scheduling overhead 

100 loop = asyncio.get_running_loop() 

101 coro = self._send_compressed_frame_async_locked(message, opcode, compress) 

102 if sys.version_info >= (3, 12): 

103 send_task = asyncio.Task(coro, loop=loop, eager_start=True) 

104 else: 

105 send_task = loop.create_task(coro) 

106 # Keep a strong reference to prevent garbage collection 

107 self._background_tasks.add(send_task) 

108 send_task.add_done_callback(self._background_tasks.discard) 

109 await asyncio.shield(send_task) 

110 

111 # It is safe to return control to the event loop when using compression 

112 # after this point as we have already sent or buffered all the data. 

113 # Once we have written output_size up to the limit, we call the 

114 # drain helper which waits for the transport to be ready to accept 

115 # more data. This is a flow control mechanism to prevent the buffer 

116 # from growing too large. The drain helper will return right away 

117 # if the writer is not paused. 

118 if self._output_size > self._limit: 

119 self._output_size = 0 

120 if self.protocol._paused: 

121 await self.protocol._drain_helper() 

122 

123 def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None: 

124 """ 

125 Write a websocket frame to the transport. 

126 

127 This method handles frame header construction, masking, and writing to transport. 

128 It does not handle compression or flow control - those are the responsibility 

129 of the caller. 

130 """ 

131 msg_length = len(message) 

132 

133 use_mask = self.use_mask 

134 mask_bit = 0x80 if use_mask else 0 

135 

136 # Depending on the message length, the header is assembled differently. 

137 # The first byte is reserved for the opcode and the RSV bits. 

138 first_byte = 0x80 | rsv | opcode 

139 if msg_length < 126: 

140 header = PACK_LEN1(first_byte, msg_length | mask_bit) 

141 header_len = 2 

142 elif msg_length < 65536: 

143 header = PACK_LEN2(first_byte, 126 | mask_bit, msg_length) 

144 header_len = 4 

145 else: 

146 header = PACK_LEN3(first_byte, 127 | mask_bit, msg_length) 

147 header_len = 10 

148 

149 if self.transport.is_closing(): 

150 raise ClientConnectionResetError("Cannot write to closing transport") 

151 

152 # https://datatracker.ietf.org/doc/html/rfc6455#section-5.3 

153 # If we are using a mask, we need to generate it randomly 

154 # and apply it to the message before sending it. A mask is 

155 # a 32-bit value that is applied to the message using a 

156 # bitwise XOR operation. It is used to prevent certain types 

157 # of attacks on the websocket protocol. The mask is only used 

158 # when aiohttp is acting as a client. Servers do not use a mask. 

159 if use_mask: 

160 mask = PACK_RANDBITS(self.get_random_bits()) 

161 message_arr = bytearray(message) 

162 websocket_mask(mask, message_arr) 

163 self.transport.write(header + mask + message_arr) 

164 self._output_size += MASK_LEN 

165 elif msg_length > MSG_SIZE: 

166 self.transport.write(header) 

167 self.transport.write(message) 

168 else: 

169 self.transport.write(header + message) 

170 

171 self._output_size += header_len + msg_length 

172 

173 def _get_compressor(self, compress: int | None) -> ZLibCompressor: 

174 """Get or create a compressor object for the given compression level.""" 

175 if compress: 

176 # Do not set self._compress if compressing is for this frame 

177 return ZLibCompressor( 

178 level=ZLibBackend.Z_BEST_SPEED, 

179 wbits=-compress, 

180 max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, 

181 ) 

182 if not self._compressobj: 

183 self._compressobj = ZLibCompressor( 

184 level=ZLibBackend.Z_BEST_SPEED, 

185 wbits=-self.compress, 

186 max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE, 

187 ) 

188 return self._compressobj 

189 

190 def _send_compressed_frame_sync( 

191 self, message: bytes, opcode: int, compress: int | None 

192 ) -> None: 

193 """ 

194 Synchronous send for small compressed frames. 

195 

196 This is used for small compressed payloads that compress synchronously in the event loop. 

197 Since there are no await points, this is inherently cancellation-safe. 

198 """ 

199 # RSV are the reserved bits in the frame header. They are used to 

200 # indicate that the frame is using an extension. 

201 # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 

202 compressobj = self._get_compressor(compress) 

203 # (0x40) RSV1 is set for compressed frames 

204 # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 

205 self._write_websocket_frame( 

206 ( 

207 compressobj.compress_sync(message) 

208 + compressobj.flush( 

209 ZLibBackend.Z_FULL_FLUSH 

210 if self.notakeover 

211 else ZLibBackend.Z_SYNC_FLUSH 

212 ) 

213 ).removesuffix(WS_DEFLATE_TRAILING), 

214 opcode, 

215 0x40, 

216 ) 

217 

218 async def _send_compressed_frame_async_locked( 

219 self, message: bytes, opcode: int, compress: int | None 

220 ) -> None: 

221 """ 

222 Async send for large compressed frames with lock. 

223 

224 Acquires the lock and compresses large payloads asynchronously in 

225 the executor. The lock is held for the entire operation to ensure 

226 the compressor state is not corrupted by concurrent sends. 

227 

228 MUST be run shielded from cancellation. If cancelled after 

229 compression but before sending, the compressor state would be 

230 advanced but data not sent, corrupting subsequent frames. 

231 """ 

232 async with self._send_lock: 

233 # RSV are the reserved bits in the frame header. They are used to 

234 # indicate that the frame is using an extension. 

235 # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2 

236 compressobj = self._get_compressor(compress) 

237 # (0x40) RSV1 is set for compressed frames 

238 # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1 

239 self._write_websocket_frame( 

240 ( 

241 await compressobj.compress(message) 

242 + compressobj.flush( 

243 ZLibBackend.Z_FULL_FLUSH 

244 if self.notakeover 

245 else ZLibBackend.Z_SYNC_FLUSH 

246 ) 

247 ).removesuffix(WS_DEFLATE_TRAILING), 

248 opcode, 

249 0x40, 

250 ) 

251 

252 async def close(self, code: int = 1000, message: bytes | str = b"") -> None: 

253 """Close the websocket, sending the specified code and message.""" 

254 if isinstance(message, str): 

255 message = message.encode("utf-8") 

256 try: 

257 await self.send_frame( 

258 PACK_CLOSE_CODE(code) + message, opcode=WSMsgType.CLOSE 

259 ) 

260 finally: 

261 self._closing = True