1"""WebSocket protocol versions 13 and 8."""
2
3import asyncio
4import random
5import sys
6from functools import partial
7from typing import Final, Optional, Set, Union
8
9from ..base_protocol import BaseProtocol
10from ..client_exceptions import ClientConnectionResetError
11from ..compression_utils import ZLibBackend, ZLibCompressor
12from .helpers import (
13 MASK_LEN,
14 MSG_SIZE,
15 PACK_CLOSE_CODE,
16 PACK_LEN1,
17 PACK_LEN2,
18 PACK_LEN3,
19 PACK_RANDBITS,
20 websocket_mask,
21)
22from .models import WS_DEFLATE_TRAILING, WSMsgType
23
# Default ``limit`` for WebSocketWriter: once more than this many bytes have
# been written since the last check, the writer resets its counter and, if the
# protocol is paused, waits for the transport to drain.
DEFAULT_LIMIT: Final[int] = 2**16

# WebSocket opcode boundary: opcodes 0-7 are data (non-control) frames and
# 8-15 are control frames (RFC 6455 section 5.2). Control frames (ping, pong,
# close) are never compressed.
WS_CONTROL_FRAME_OPCODE: Final[int] = 8

# For websockets, keeping latency low is extremely important, as
# implementations generally expect to send and receive messages quickly.
# Payloads up to this size are compressed synchronously in the event loop;
# larger payloads are compressed in an executor. A relatively large threshold
# reduces the number of executor calls and the task-creation overhead, both of
# which are significant sources of latency when chunks are small. 16 KiB was
# chosen as a balance between avoiding that overhead and not blocking the
# event loop for too long with synchronous compression.
WEBSOCKET_MAX_SYNC_CHUNK_SIZE: Final[int] = 16 * 1024
38
39
class WebSocketWriter:
    """Low-level writer that frames WebSocket messages onto a transport.

    An instance turns payloads into RFC 6455 frames, optionally compressed
    with permessage-deflate (RSV1 set) and optionally masked, and writes them
    to the transport. Masking (``use_mask``) is what the client side of a
    connection uses; servers send unmasked frames. The writer is used by both
    sides of a connection. It deliberately contains no application logic and
    deals only with wire-format details and basic flow control.
    """

    def __init__(
        self,
        protocol: BaseProtocol,
        transport: asyncio.Transport,
        *,
        use_mask: bool = False,
        limit: int = DEFAULT_LIMIT,
        random: random.Random = random.Random(),
        compress: int = 0,
        notakeover: bool = False,
    ) -> None:
        """Prepare a writer bound to ``protocol`` and ``transport``.

        :param protocol: protocol whose ``_paused`` flag decides whether a
            send must wait for the transport to drain.
        :param transport: transport that frames are written to.
        :param use_mask: mask every outgoing frame (client behaviour).
        :param limit: bytes written between flow-control checks.
        :param random: source of masking keys. NOTE: the default instance is
            created once, when the function is defined, and is shared by every
            writer that does not pass its own.
        :param compress: default deflate window bits. ``0`` disables
            compression unless a frame supplies its own value.
        :param notakeover: finish each compressed message with
            ``Z_FULL_FLUSH`` instead of ``Z_SYNC_FLUSH``.
        """
        # Public configuration.
        self.protocol = protocol
        self.transport = transport
        self.use_mask = use_mask
        self.compress = compress
        self.notakeover = notakeover
        # Masking keys are 32-bit values (RFC 6455 section 5.3).
        self.get_random_bits = partial(random.getrandbits, 32)

        # Internal state.
        self._closing = False
        self._limit = limit
        # Bytes written since the last flow-control check.
        self._output_size = 0
        # Connection-wide compressor, created lazily by _get_compressor.
        self._compressobj: Optional[ZLibCompressor] = None
        # Serialises use of the shared compressor between concurrent sends.
        self._send_lock = asyncio.Lock()
        # Strong references to in-flight shielded send tasks.
        self._background_tasks: Set[asyncio.Task[None]] = set()

    async def send_frame(
        self, message: bytes, opcode: int, compress: Optional[int] = None
    ) -> None:
        """Send ``message`` as the payload of one frame with ``opcode``.

        :param message: frame payload.
        :param opcode: WebSocket opcode of the frame.
        :param compress: per-frame deflate window bits overriding the
            writer's default. A falsy value means ``self.compress`` is used.
        :raises ClientConnectionResetError: if the writer is closing and the
            opcode does not have the ``WSMsgType.CLOSE`` bit set, or if the
            transport is closing.
        """
        if self._closing and not (opcode & WSMsgType.CLOSE):
            raise ClientConnectionResetError("Cannot write to closing transport")

        is_control_frame = opcode >= WS_CONTROL_FRAME_OPCODE
        if is_control_frame or not (compress or self.compress):
            # Uncompressed frames touch no compressor state, so they are
            # written straight away without the lock or a shield.
            self._write_websocket_frame(message, opcode, 0)
        elif len(message) > WEBSOCKET_MAX_SYNC_CHUNK_SIZE:
            # Large payloads are compressed in an executor (an await point);
            # compress + write must not be split apart by cancellation.
            await self._send_compressed_frame_shielded(message, opcode, compress)
        else:
            # Small payloads are compressed inline with no await point, but
            # the lock is still required. Without it this frame could run
            # between the compression and the write of a large frame that
            # shares the same compressor, corrupting the stream.
            async with self._send_lock:
                self._send_compressed_frame_sync(message, opcode, compress)

        # The frame is fully written or buffered at this point, so yielding
        # to the event loop is safe. Once more than ``_limit`` bytes have
        # accumulated, reset the counter and wait for the transport to drain
        # if the protocol reports that writing is paused.
        if self._output_size <= self._limit:
            return
        self._output_size = 0
        if self.protocol._paused:
            await self.protocol._drain_helper()

    async def _send_compressed_frame_shielded(
        self, message: bytes, opcode: int, compress: Optional[int]
    ) -> None:
        """Run the locked compress-and-send of a large frame in a shielded task.

        The lock is acquired inside the task, so lock + compress + write run
        as one unit. Cancelling the caller only abandons the wait; the task
        keeps running so the compressor state never gets ahead of the bytes
        actually written.
        """
        loop = asyncio.get_running_loop()
        job = self._send_compressed_frame_async_locked(message, opcode, compress)
        if sys.version_info >= (3, 12):
            # Eager start runs the task up to its first suspension point
            # right away, avoiding an extra trip through the scheduler.
            task = asyncio.Task(job, loop=loop, eager_start=True)
        else:
            task = loop.create_task(job)
        # The event loop only keeps weak references to tasks, so hold a
        # strong one until the task is done.
        background = self._background_tasks
        background.add(task)
        task.add_done_callback(background.discard)
        await asyncio.shield(task)

    def _write_websocket_frame(self, message: bytes, opcode: int, rsv: int) -> None:
        """Frame ``message`` and write it to the transport.

        Builds the header (FIN bit, ``rsv`` bits, ``opcode`` and the payload
        length in the shortest RFC 6455 length encoding). It masks the payload
        when ``use_mask`` is set and adds the written byte count to
        ``_output_size``. Compression and flow control are the caller's
        responsibility.

        :raises ClientConnectionResetError: if the transport is closing.
        """
        payload_len = len(message)
        masked = self.use_mask
        mask_flag = 0x80 if masked else 0

        # First header byte: FIN (0x80) combined with the RSV bits and opcode.
        fin_rsv_opcode = 0x80 | rsv | opcode
        if payload_len >= 65536:
            # Length marker 127 followed by an 8-byte extended length.
            header = PACK_LEN3(fin_rsv_opcode, 127 | mask_flag, payload_len)
            header_size = 10
        elif payload_len >= 126:
            # Length marker 126 followed by a 2-byte extended length.
            header = PACK_LEN2(fin_rsv_opcode, 126 | mask_flag, payload_len)
            header_size = 4
        else:
            # Length fits in the 7-bit field of the second header byte.
            header = PACK_LEN1(fin_rsv_opcode, payload_len | mask_flag)
            header_size = 2

        if self.transport.is_closing():
            raise ClientConnectionResetError("Cannot write to closing transport")

        write = self.transport.write
        if masked:
            # https://datatracker.ietf.org/doc/html/rfc6455#section-5.3
            # A fresh random 32-bit key is XOR-ed over a mutable copy of the
            # payload. Only aiohttp acting as a client masks frames.
            mask = PACK_RANDBITS(self.get_random_bits())
            masked_payload = bytearray(message)
            websocket_mask(mask, masked_payload)
            write(header + mask + masked_payload)
            self._output_size += MASK_LEN
        elif payload_len > MSG_SIZE:
            # Write separately: concatenating would copy the whole payload.
            write(header)
            write(message)
        else:
            write(header + message)

        self._output_size += header_size + payload_len

    def _get_compressor(self, compress: Optional[int]) -> ZLibCompressor:
        """Return the compressor to use for one frame.

        A truthy ``compress`` yields a fresh, single-use compressor with that
        window size, which is intentionally not cached in ``_compressobj``.
        Otherwise the connection-wide compressor built from ``self.compress``
        is created on first use and reused afterwards.
        """

        def new_compressor(window_bits: int) -> ZLibCompressor:
            # Negative wbits selects a raw deflate stream (no zlib header).
            return ZLibCompressor(
                level=ZLibBackend.Z_BEST_SPEED,
                wbits=-window_bits,
                max_sync_chunk_size=WEBSOCKET_MAX_SYNC_CHUNK_SIZE,
            )

        if compress:
            return new_compressor(compress)
        if not self._compressobj:
            self._compressobj = new_compressor(self.compress)
        return self._compressobj

    def _flush_and_write_compressed(
        self, compressobj: ZLibCompressor, deflated: bytes, opcode: int
    ) -> None:
        """Flush ``compressobj``, strip the deflate tail and write the frame.

        The flush mode is ``Z_FULL_FLUSH`` with ``notakeover`` and
        ``Z_SYNC_FLUSH`` otherwise. ``WS_DEFLATE_TRAILING`` is removed from
        the end of the payload before framing.
        """
        flush_mode = (
            ZLibBackend.Z_FULL_FLUSH if self.notakeover else ZLibBackend.Z_SYNC_FLUSH
        )
        payload = (deflated + compressobj.flush(flush_mode)).removesuffix(
            WS_DEFLATE_TRAILING
        )
        # RSV1 (0x40) marks a compressed frame; RSV bits signal extensions.
        # https://datatracker.ietf.org/doc/html/rfc6455#section-5.2
        # https://datatracker.ietf.org/doc/html/rfc7692#section-7.2.3.1
        self._write_websocket_frame(payload, opcode, 0x40)

    def _send_compressed_frame_sync(
        self, message: bytes, opcode: int, compress: Optional[int]
    ) -> None:
        """Compress a small payload inline and write it as a single frame.

        Used for payloads small enough to compress inside the event loop.
        There is no await point, so a cancellation cannot leave the
        compressor advanced without the matching bytes being written.
        """
        compressobj = self._get_compressor(compress)
        deflated = compressobj.compress_sync(message)
        self._flush_and_write_compressed(compressobj, deflated, opcode)

    async def _send_compressed_frame_async_locked(
        self, message: bytes, opcode: int, compress: Optional[int]
    ) -> None:
        """Compress a large payload in the executor and write it, under the lock.

        The send lock is held for the whole compress + write sequence so that
        no other send can use the compressor in between.

        MUST be run shielded from cancellation: cancelling after compression
        but before the write would advance the compressor state without
        sending the data, corrupting every later frame.
        """
        async with self._send_lock:
            compressobj = self._get_compressor(compress)
            deflated = await compressobj.compress(message)
            self._flush_and_write_compressed(compressobj, deflated, opcode)

    async def close(self, code: int = 1000, message: Union[bytes, str] = b"") -> None:
        """Send a Close frame carrying ``code`` and ``message``.

        A ``str`` message is UTF-8 encoded first. The writer is marked as
        closing even if the send fails. After that, ``send_frame`` rejects
        frames whose opcode lacks the ``WSMsgType.CLOSE`` bit.
        """
        reason = message.encode("utf-8") if isinstance(message, str) else message
        try:
            await self.send_frame(
                PACK_CLOSE_CODE(code) + reason, opcode=WSMsgType.CLOSE
            )
        finally:
            self._closing = True