Coverage for /pythoncovmergedfiles/medio/medio/src/aiohttp/aiohttp/_websocket/writer.py: 24%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1"""WebSocket protocol versions 13 and 8."""
3import asyncio
4import random
5import sys
6from functools import partial
7from typing import Final
9from ..base_protocol import BaseProtocol
10from ..client_exceptions import ClientConnectionResetError
11from ..compression_utils import ZLibBackend, ZLibCompressor
12from ..helpers import DEFAULT_CHUNK_SIZE
13from .helpers import (
14 MASK_LEN,
15 MSG_SIZE,
16 PACK_CLOSE_CODE,
17 PACK_LEN1,
18 PACK_LEN2,
19 PACK_LEN3,
20 PACK_RANDBITS,
21 websocket_mask,
22)
23from .models import WS_DEFLATE_TRAILING, WSMsgType
# WebSocket opcode boundary (RFC 6455, section 5.2): opcodes 0x0-0x7 are
# data (non-control) frames and 0x8-0xF are control frames. Control frames
# (close, ping, pong) are never compressed.
WS_CONTROL_FRAME_OPCODE: Final[int] = 8

# For websockets, keeping latency low is extremely important because
# implementations generally expect to send and receive messages quickly.
# Compressed payloads up to this size are compressed synchronously in the
# event loop; larger ones are compressed in the executor. A larger threshold
# means fewer executor calls and less task-creation overhead, both of which
# are significant sources of latency when chunks are small. 16 KiB balances
# that overhead against blocking the event loop for too long with
# synchronous compression.
WEBSOCKET_MAX_SYNC_CHUNK_SIZE: Final[int] = 16 * 1024
class WebSocketWriter:
    """Low-level frame writer for one WebSocket connection.

    The owning protocol creates an instance once the connection is
    established. The writer turns payloads into RFC 6455 frames and writes
    them to the transport. It handles:

    * the frame header,
    * the optional client-side mask,
    * optional permessage-deflate compression,
    * simple flow control.

    Application-level logic does not belong here.
    """

    def __init__(
        self,
        protocol: BaseProtocol,
        transport: asyncio.Transport,
        *,
        use_mask: bool = False,
        limit: int = DEFAULT_CHUNK_SIZE,
        random: random.Random = random.Random(),
        compress: int = 0,
        notakeover: bool = False,
    ) -> None:
        """Prepare writer state for *transport*.

        :param protocol: protocol whose ``_paused`` flag and
            ``_drain_helper`` coroutine drive flow control.
        :param transport: transport the encoded frames are written to.
        :param use_mask: mask every outgoing payload (done when acting as
            a client).
        :param limit: once more than this many bytes have been written since
            the last check, the writer consults the protocol's pause state.
        :param random: source of the 32-bit masking keys.
            NOTE: the default instance is created once, when the class is
            defined, and is shared by every writer that does not pass its
            own.
        :param compress: writer-wide deflate window bits (passed to the
            compressor as ``wbits=-compress``). A value of 0 means frames are
            only compressed when ``send_frame`` asks for it.
        :param notakeover: flush with ``Z_FULL_FLUSH`` instead of
            ``Z_SYNC_FLUSH`` after each compressed frame.
        """
        self.protocol = protocol
        self.transport = transport
        self.use_mask = use_mask
        self.compress = compress
        self.notakeover = notakeover
        # Pre-bind the bit count so each mask costs a single call.
        self.get_random_bits = partial(random.getrandbits, 32)
        self._limit = limit
        self._closing = False
        # Bytes written since the last flow-control check.
        self._output_size = 0
        # Writer-wide compressor, created lazily by _get_compressor().
        self._compressobj: ZLibCompressor | None = None
        # Serializes compressed sends so the compressor's state advances in
        # the same order the frames reach the transport.
        self._send_lock = asyncio.Lock()
        # Strong references to in-flight shielded send tasks.
        self._background_tasks: set[asyncio.Task[None]] = set()

    async def send_frame(
        self, message: bytes, opcode: int, compress: int | None = None
    ) -> None:
        """Send *message* as the payload of a single frame with *opcode*.

        The frame is sent along one of three paths:

        * Uncompressed and control frames are written immediately.
        * Small compressed frames are compressed inline under the send lock.
        * Large compressed frames are compressed in the executor, inside a
          task shielded from cancellation.

        :param message: frame payload.
        :param opcode: WebSocket opcode for the frame.
        :param compress: per-frame deflate window bits. When falsy, the
            writer-wide ``compress`` setting applies instead.
        :raises ClientConnectionResetError: if the writer is closing and
            *opcode* shares no bit with ``WSMsgType.CLOSE`` (presumably this
            admits close/ping/pong only), or if the transport is already
            closing.
        """
        if self._closing and not (opcode & WSMsgType.CLOSE):
            raise ClientConnectionResetError("Cannot write to closing transport")

        deflate = bool(compress or self.compress) and opcode < WS_CONTROL_FRAME_OPCODE
        if not deflate:
            # Plain frames never touch the compressor, so neither the lock
            # nor a shield is needed.
            self._write_websocket_frame(message, opcode, 0)
        elif len(message) > WEBSOCKET_MAX_SYNC_CHUNK_SIZE:
            await self._send_compressed_frame_shielded(message, opcode, compress)
        else:
            # Inline compression has no await points, but the lock is still
            # taken. Otherwise a small frame could slip in while a large frame
            # is compressing in the executor, which would interleave updates
            # to the shared compressor state.
            async with self._send_lock:
                self._send_compressed_frame_sync(message, opcode, compress)

        # The frame is fully written or buffered at this point, so yielding
        # to the event loop is safe. Each time more than ``limit`` bytes have
        # gone out, reset the counter. If the protocol reports that writing
        # is paused, wait until it drains.
        if self._output_size <= self._limit:
            return
        self._output_size = 0
        if self.protocol._paused:
            await self.protocol._drain_helper()

    async def _send_compressed_frame_shielded(
        self, message: bytes, opcode: int, compress: int | None
    ) -> None:
        """Run ``_send_compressed_frame_async_locked`` in a shielded task.

        Lock, compress and send must complete as one unit. If cancellation
        landed after compression but before the write, the compressor state
        would move ahead of what was actually sent, corrupting later frames.
        Running the work in its own task and awaiting it through
        ``asyncio.shield`` lets it finish even when the caller is cancelled.
        """
        loop = asyncio.get_running_loop()
        coro = self._send_compressed_frame_async_locked(message, opcode, compress)
        if sys.version_info >= (3, 12):
            # Eager start avoids a round trip through the scheduler.
            task = asyncio.Task(coro, loop=loop, eager_start=True)
        else:
            task = loop.create_task(coro)
        # The event loop keeps only weak references to tasks, so hold a
        # strong one until the task finishes.
        pending = self._background_tasks
        pending.add(task)
        task.add_done_callback(pending.discard)
        await asyncio.shield(task)

    def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None:
        """Build a frame header for *message* and write the frame.

        The first header byte is 0x80 (the FIN bit) OR'd with *rsv* and
        *opcode*. The payload length is encoded in a 2-, 4- or 10-byte
        header, depending on its size. When ``use_mask`` is set, a fresh
        random 32-bit mask is applied to a copy of the payload (RFC 6455,
        section 5.3); this is only done on the client side.

        Compression and flow control are the caller's job. Only the
        ``_output_size`` counter is updated here, and only after the
        write succeeds.

        :raises ClientConnectionResetError: if the transport is closing.
        """
        length = len(message)
        masked = self.use_mask
        mask_flag = 0x80 if masked else 0
        lead_byte = 0x80 | rsv | opcode

        if length >= 65536:
            header = PACK_LEN3(lead_byte, 127 | mask_flag, length)
            written = 10
        elif length >= 126:
            header = PACK_LEN2(lead_byte, 126 | mask_flag, length)
            written = 4
        else:
            header = PACK_LEN1(lead_byte, length | mask_flag)
            written = 2

        if self.transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")

        if masked:
            # The mask is only drawn once the transport is known to be open.
            mask = PACK_RANDBITS(self.get_random_bits())
            payload = bytearray(message)
            websocket_mask(mask, payload)
            self.transport.write(header + mask + payload)
            written += MASK_LEN
        elif length > MSG_SIZE:
            # Large payloads go out as two writes so the payload is not
            # copied just to prepend the header.
            self.transport.write(header)
            self.transport.write(message)
        else:
            self.transport.write(header + message)

        self._output_size += written + length

    def _get_compressor(self, compress: int | None) -> ZLibCompressor:
        """Return the compressor to use for one frame.

        A truthy *compress* gets a one-off compressor with those window bits,
        and the writer-wide compressor is left alone. Otherwise the
        writer-wide compressor is created on first use (from ``self.compress``)
        and reused after that.
        """
        if not compress and self._compressobj:
            return self._compressobj
        compressor = ZLibCompressor(
            level=ZLibBackend.Z_BEST_SPEED,
            wbits=-(compress or self.compress),
            max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
        )
        if not compress:
            self._compressobj = compressor
        return compressor

    def _deflate_tail(self, compressobj: ZLibCompressor, head: bytes) -> bytes:
        """Append the flush output of *compressobj* to *head* and trim it.

        ``notakeover`` selects ``Z_FULL_FLUSH``; otherwise ``Z_SYNC_FLUSH`` is
        used. A trailing ``WS_DEFLATE_TRAILING`` is removed (presumably the
        0x00 0x00 0xff 0xff marker that RFC 7692, section 7.2.1 tells senders
        to drop).
        """
        mode = (
            ZLibBackend.Z_FULL_FLUSH
            if self.notakeover
            else ZLibBackend.Z_SYNC_FLUSH
        )
        return (head + compressobj.flush(mode)).removesuffix(WS_DEFLATE_TRAILING)

    def _send_compressed_frame_sync(
        self, message: bytes, opcode: int, compress: int | None
    ) -> None:
        """Compress a small payload inline and write it as one frame.

        Nothing here awaits, so this cannot be interrupted by cancellation
        part-way through.
        """
        compressobj = self._get_compressor(compress)
        payload = self._deflate_tail(compressobj, compressobj.compress_sync(message))
        # 0x40 is RSV1, which RFC 7692 (section 7.2.3.1) uses to mark a
        # compressed message; see RFC 6455, section 5.2, for the RSV bits.
        self._write_websocket_frame(payload, opcode, 0x40)

    async def _send_compressed_frame_async_locked(
        self, message: bytes, opcode: int, compress: int | None
    ) -> None:
        """Compress a large payload in the executor and write it, under the lock.

        The send lock is held for the whole compress-and-write sequence so
        no other compressed frame can advance the compressor in between.

        MUST be run shielded from cancellation: being cancelled after
        compressing but before writing would leave the compressor ahead of
        the data actually sent, corrupting later frames.
        """
        async with self._send_lock:
            compressobj = self._get_compressor(compress)
            compressed = await compressobj.compress(message)
            # RSV1 (0x40) flags the frame as compressed (RFC 7692).
            self._write_websocket_frame(
                self._deflate_tail(compressobj, compressed), opcode, 0x40
            )

    async def close(self, code: int = 1000, message: bytes | str = b"") -> None:
        """Send a close frame carrying *code* and *message*.

        A ``str`` message is UTF-8 encoded first. The writer is marked as
        closing even if sending the frame fails. From then on, ``send_frame``
        rejects frames whose opcode shares no bit with ``WSMsgType.CLOSE``.
        """
        reason = message.encode("utf-8") if isinstance(message, str) else message
        try:
            await self.send_frame(PACK_CLOSE_CODE(code) + reason, opcode=WSMsgType.CLOSE)
        finally:
            self._closing = True