1from __future__ import annotations
2
3import dataclasses
4import enum
5import io
6import os
7import secrets
8import struct
9from collections.abc import Generator, Sequence
10from typing import Callable, Union
11
12from .exceptions import PayloadTooBig, ProtocolError
13
14
15try:
16 from .speedups import apply_mask
17except ImportError:
18 from .utils import apply_mask
19
20
21__all__ = [
22 "Opcode",
23 "OP_CONT",
24 "OP_TEXT",
25 "OP_BINARY",
26 "OP_CLOSE",
27 "OP_PING",
28 "OP_PONG",
29 "DATA_OPCODES",
30 "CTRL_OPCODES",
31 "CloseCode",
32 "Frame",
33 "Close",
34]
35
36
class Opcode(enum.IntEnum):
    """Opcode values for WebSocket frames (the low four bits of the first header byte)."""

    # Data frames.
    CONT = 0x00
    TEXT = 0x01
    BINARY = 0x02

    # Control frames.
    CLOSE = 0x08
    PING = 0x09
    PONG = 0x0A
42
43
# Module-level aliases for the Opcode members.
OP_CONT, OP_TEXT, OP_BINARY = Opcode.CONT, Opcode.TEXT, Opcode.BINARY
OP_CLOSE, OP_PING, OP_PONG = Opcode.CLOSE, Opcode.PING, Opcode.PONG

# Opcodes grouped by frame category, as tuples.
# Frame.check() applies extra constraints to CTRL_OPCODES.
DATA_OPCODES = (OP_CONT, OP_TEXT, OP_BINARY)
CTRL_OPCODES = (OP_CLOSE, OP_PING, OP_PONG)
53
54
class CloseCode(enum.IntEnum):
    """
    Close code values for WebSocket close frames.

    NO_STATUS_RCVD, ABNORMAL_CLOSURE and TLS_HANDSHAKE are internal codes:
    they describe a closure but are not accepted inside a close frame.

    """

    NORMAL_CLOSURE, GOING_AWAY = 1000, 1001
    PROTOCOL_ERROR, UNSUPPORTED_DATA = 1002, 1003
    # 1004 is reserved and intentionally has no member.
    NO_STATUS_RCVD, ABNORMAL_CLOSURE = 1005, 1006
    INVALID_DATA, POLICY_VIOLATION = 1007, 1008
    MESSAGE_TOO_BIG, MANDATORY_EXTENSION = 1009, 1010
    INTERNAL_ERROR, SERVICE_RESTART = 1011, 1012
    TRY_AGAIN_LATER, BAD_GATEWAY = 1013, 1014
    TLS_HANDSHAKE = 1015
74
75
# Human-readable labels that Close.__str__ shows for codes outside 3000-4999.
# Labels tagged "[internal]" belong to codes that are not valid in a close
# frame (they are absent from EXTERNAL_CLOSE_CODES).
# Source: https://www.iana.org/assignments/websocket/websocket.xhtml
CLOSE_CODE_EXPLANATIONS: dict[int, str] = {
    # 1000-1003
    CloseCode.NORMAL_CLOSURE: "OK",
    CloseCode.GOING_AWAY: "going away",
    CloseCode.PROTOCOL_ERROR: "protocol error",
    CloseCode.UNSUPPORTED_DATA: "unsupported data",
    # 1005-1006, internal only
    CloseCode.NO_STATUS_RCVD: "no status received [internal]",
    CloseCode.ABNORMAL_CLOSURE: "abnormal closure [internal]",
    # 1007-1014
    CloseCode.INVALID_DATA: "invalid frame payload data",
    CloseCode.POLICY_VIOLATION: "policy violation",
    CloseCode.MESSAGE_TOO_BIG: "message too big",
    CloseCode.MANDATORY_EXTENSION: "mandatory extension",
    CloseCode.INTERNAL_ERROR: "internal error",
    CloseCode.SERVICE_RESTART: "service restart",
    CloseCode.TRY_AGAIN_LATER: "try again later",
    CloseCode.BAD_GATEWAY: "bad gateway",
    # 1015, internal only
    CloseCode.TLS_HANDSHAKE: "TLS handshake failure [internal]",
}
94
95
# Close codes that are allowed in a close frame; Close.check() rejects any
# other code outside the 3000-4999 range. NO_STATUS_RCVD, ABNORMAL_CLOSURE
# and TLS_HANDSHAKE are deliberately missing. A set keeps the
# `code in EXTERNAL_CLOSE_CODES` membership test fast.
EXTERNAL_CLOSE_CODES = {
    # 1000-1003
    CloseCode.NORMAL_CLOSURE, CloseCode.GOING_AWAY,
    CloseCode.PROTOCOL_ERROR, CloseCode.UNSUPPORTED_DATA,
    # 1007-1014
    CloseCode.INVALID_DATA, CloseCode.POLICY_VIOLATION,
    CloseCode.MESSAGE_TOO_BIG, CloseCode.MANDATORY_EXTENSION,
    CloseCode.INTERNAL_ERROR, CloseCode.SERVICE_RESTART,
    CloseCode.TRY_AGAIN_LATER, CloseCode.BAD_GATEWAY,
}


# NOTE(review): not used in this file; judging by the name, these are the
# codes callers treat as a clean (non-error) closure — confirm with callers.
OK_CLOSE_CODES = {
    CloseCode.NORMAL_CLOSURE, CloseCode.GOING_AWAY, CloseCode.NO_STATUS_RCVD,
}
119
120
# The bytes-like types a payload may have, as a tuple (usable, for example,
# as the second argument of isinstance()).
BytesLike = (bytes, bytearray, memoryview)
122
123
@dataclasses.dataclass
class Frame:
    """
    WebSocket frame.

    Attributes:
        opcode: Opcode.
        data: Payload data.
        fin: FIN bit.
        rsv1: RSV1 bit.
        rsv2: RSV2 bit.
        rsv3: RSV3 bit.

    Only these fields are needed. The MASK bit, payload length and masking-key
    are handled on the fly when parsing and serializing frames.

    """

    opcode: Opcode
    data: Union[bytes, bytearray, memoryview]
    fin: bool = True
    rsv1: bool = False
    rsv2: bool = False
    rsv3: bool = False

    # Configure if you want to see more in logs. Should be a multiple of 3.
    # It has no annotation, so dataclasses treats it as a plain class
    # attribute rather than a field.
    MAX_LOG_SIZE = int(os.environ.get("WEBSOCKETS_MAX_LOG_SIZE", "75"))

    def __str__(self) -> str:
        """
        Return a human-readable representation of a frame.

        The result looks like ``"<OPCODE> <payload> [<metadata>]"``. Payloads
        whose rendering exceeds MAX_LOG_SIZE characters are elided in the
        middle.

        This is meant for logging, so it doesn't raise on malformed payloads.
        TEXT payloads that aren't valid UTF-8 and CLOSE payloads that can't
        be parsed are displayed as hexadecimal bytes tagged "binary".

        """
        coding = None
        length = f"{len(self.data)} byte{'' if len(self.data) == 1 else 's'}"
        non_final = "" if self.fin else "continued"

        if self.opcode is OP_TEXT:
            # Decoding only the beginning and the end is needlessly hard.
            # Decode the entire payload then elide later if necessary.
            # A non-final fragment may end in the middle of a multi-byte
            # character, hence the fallback to hex.
            try:
                data = repr(bytes(self.data).decode())
            except UnicodeDecodeError:
                data = self._format_binary()
                coding = "binary"
        elif self.opcode is OP_BINARY:
            data = self._format_binary()
        elif self.opcode is OP_CLOSE:
            # Pass bytes: the payload may be a memoryview, which has no
            # decode() method for Close.parse to use on the reason.
            try:
                data = str(Close.parse(bytes(self.data)))
            except (ProtocolError, UnicodeDecodeError):
                data = self._format_binary()
                coding = "binary"
        elif self.data:
            # We don't know if a Continuation frame contains text or binary.
            # Ping and Pong frames could contain UTF-8.
            # Attempt to decode as UTF-8 and display it as text; fallback to
            # binary.
            try:
                data = repr(bytes(self.data).decode())
                coding = "text"
            except UnicodeDecodeError:
                data = self._format_binary()
                coding = "binary"
        else:
            data = "''"

        data = self._elide(data)

        metadata = ", ".join(filter(None, [coding, length, non_final]))

        return f"{self.opcode.name} {data} [{metadata}]"

    def _format_binary(self) -> str:
        """
        Return the payload as space-separated hexadecimal bytes for logging.

        Long payloads are first trimmed to their first ``2 * cut`` and last
        ``cut`` bytes around two placeholder bytes, instead of formatting the
        whole payload. With the default MAX_LOG_SIZE (75), cut is 8 and the
        placeholders fall entirely in the part that _elide() removes.

        """
        binary = self.data
        if len(binary) > self.MAX_LOG_SIZE // 3:
            # Clamp at 0 so tiny MAX_LOG_SIZE values don't yield negative cuts.
            cut = max((self.MAX_LOG_SIZE // 3 - 1) // 3, 0)  # by default cut = 8
            # Use an explicit start index: binary[-0:] would keep everything.
            binary = b"".join(
                [binary[: 2 * cut], b"\x00\x00", binary[len(binary) - cut :]]
            )
        return " ".join(f"{byte:02x}" for byte in binary)

    def _elide(self, text: str) -> str:
        """
        Shorten ``text`` to at most about MAX_LOG_SIZE characters.

        Text longer than MAX_LOG_SIZE keeps its first ``2 * cut`` and last
        ``cut`` characters around ``"..."``; shorter text is returned as is.

        """
        if len(text) <= self.MAX_LOG_SIZE:
            return text
        cut = max(self.MAX_LOG_SIZE // 3 - 1, 0)  # by default cut = 24
        # Explicit start index: text[-0:] would keep the whole string.
        return text[: 2 * cut] + "..." + text[len(text) - cut :]

    @classmethod
    def parse(
        cls,
        read_exact: Callable[[int], Generator[None, None, bytes]],
        *,
        mask: bool,
        max_size: int | None = None,
        extensions: Sequence[extensions.Extension] | None = None,
    ) -> Generator[None, None, Frame]:
        """
        Parse a WebSocket frame.

        This is a generator-based coroutine.

        Args:
            read_exact: Generator-based coroutine that reads the requested
                bytes or raises an exception if there isn't enough data.
            mask: Whether the frame should be masked i.e. whether the read
                happens on the server side.
            max_size: Maximum payload size in bytes.
            extensions: List of extensions, applied in reverse order.

        Raises:
            EOFError: If the connection is closed without a full WebSocket frame.
            PayloadTooBig: If the frame's payload size exceeds ``max_size``.
            ProtocolError: If the frame contains incorrect values, including
                a control frame announcing more than 125 bytes of payload.

        """
        # Read the header.
        data = yield from read_exact(2)
        head1, head2 = struct.unpack("!BB", data)

        # While not Pythonic, this is marginally faster than calling bool().
        fin = True if head1 & 0b10000000 else False
        rsv1 = True if head1 & 0b01000000 else False
        rsv2 = True if head1 & 0b00100000 else False
        rsv3 = True if head1 & 0b00010000 else False

        try:
            opcode = Opcode(head1 & 0b00001111)
        except ValueError as exc:
            raise ProtocolError("invalid opcode") from exc

        if (True if head2 & 0b10000000 else False) != mask:
            raise ProtocolError("incorrect masking")

        # 7-bit length, or 126 / 127 announcing a 16-bit / 64-bit length.
        length = head2 & 0b01111111
        if length == 126:
            data = yield from read_exact(2)
            (length,) = struct.unpack("!H", data)
        elif length == 127:
            data = yield from read_exact(8)
            (length,) = struct.unpack("!Q", data)
        if max_size is not None and length > max_size:
            raise PayloadTooBig(length, max_size)
        # Control frames carry at most 125 bytes. check() enforces this too,
        # but only after the payload was read; rejecting here avoids reading
        # a bogus, possibly huge payload when max_size is None.
        if opcode in CTRL_OPCODES and length > 125:
            raise ProtocolError("control frame too long")
        if mask:
            mask_bytes = yield from read_exact(4)

        # Read the data.
        data = yield from read_exact(length)
        if mask:
            data = apply_mask(data, mask_bytes)

        frame = cls(opcode, data, fin, rsv1, rsv2, rsv3)

        if extensions is None:
            extensions = []
        for extension in reversed(extensions):
            frame = extension.decode(frame, max_size=max_size)

        frame.check()

        return frame

    def serialize(
        self,
        *,
        mask: bool,
        extensions: Sequence[extensions.Extension] | None = None,
    ) -> bytes:
        """
        Serialize a WebSocket frame.

        Args:
            mask: Whether the frame should be masked i.e. whether the write
                happens on the client side.
            extensions: List of extensions, applied in order.

        Returns:
            The frame header (including the masking key when ``mask`` is set)
            followed by the payload, masked when ``mask`` is set.

        Raises:
            ProtocolError: If the frame contains incorrect values.

        """
        # Validate before extensions run, since they may set reserved bits.
        self.check()

        # Each extension returns the frame that the next one receives.
        frame = self
        if extensions is not None:
            for extension in extensions:
                frame = extension.encode(frame)

        output = io.BytesIO()

        # Prepare the header.
        head1 = (
            (0b10000000 if frame.fin else 0)
            | (0b01000000 if frame.rsv1 else 0)
            | (0b00100000 if frame.rsv2 else 0)
            | (0b00010000 if frame.rsv3 else 0)
            | frame.opcode
        )

        head2 = 0b10000000 if mask else 0

        # Use the shortest length encoding: 7-bit, 16-bit or 64-bit.
        length = len(frame.data)
        if length < 126:
            output.write(struct.pack("!BB", head1, head2 | length))
        elif length < 65536:
            output.write(struct.pack("!BBH", head1, head2 | 126, length))
        else:
            output.write(struct.pack("!BBQ", head1, head2 | 127, length))

        if mask:
            mask_bytes = secrets.token_bytes(4)
            output.write(mask_bytes)

        # Prepare the data.
        if mask:
            data = apply_mask(frame.data, mask_bytes)
        else:
            data = frame.data
        output.write(data)

        return output.getvalue()

    def check(self) -> None:
        """
        Check that reserved bits and opcode have acceptable values.

        Control frames must also be final and carry at most 125 bytes.

        Raises:
            ProtocolError: If a reserved bit or the opcode is invalid.

        """
        if self.rsv1 or self.rsv2 or self.rsv3:
            raise ProtocolError("reserved bits must be 0")

        if self.opcode in CTRL_OPCODES:
            if len(self.data) > 125:
                raise ProtocolError("control frame too long")
            if not self.fin:
                raise ProtocolError("fragmented control frame")
351
352
@dataclasses.dataclass
class Close:
    """
    Code and reason for WebSocket close frames.

    Attributes:
        code: Close code.
        reason: Close reason.

    """

    code: int
    reason: str

    def __str__(self) -> str:
        """
        Return a human-readable representation of a close code and reason.

        Codes 3000-3999 are labeled "registered" and 4000-4999 "private use".
        Other codes are looked up in CLOSE_CODE_EXPLANATIONS, falling back to
        "unknown". A non-empty reason is appended after the label.

        """
        if 3000 <= self.code < 4000:
            explanation = "registered"
        elif 4000 <= self.code < 5000:
            explanation = "private use"
        else:
            explanation = CLOSE_CODE_EXPLANATIONS.get(self.code, "unknown")
        result = f"{self.code} ({explanation})"

        if self.reason:
            result = f"{result} {self.reason}"

        return result

    @classmethod
    def parse(cls, data: bytes | bytearray | memoryview) -> Close:
        """
        Parse the payload of a close frame.

        Args:
            data: Payload of the close frame: bytes, bytearray or memoryview,
                i.e. any type a Frame's data may have.

        Returns:
            The parsed close code and reason; an empty payload yields
            ``NO_STATUS_RCVD`` with an empty reason.

        Raises:
            ProtocolError: If data is ill-formed or the close code is invalid.
            UnicodeDecodeError: If the reason isn't valid UTF-8.

        """
        if len(data) >= 2:
            (code,) = struct.unpack("!H", data[:2])
            # memoryview has no decode() method; converting the slice to
            # bytes works for every supported payload type.
            reason = bytes(data[2:]).decode()
            close = cls(code, reason)
            close.check()
            return close
        elif len(data) == 0:
            return cls(CloseCode.NO_STATUS_RCVD, "")
        else:
            raise ProtocolError("close frame too short")

    def serialize(self) -> bytes:
        """
        Serialize the payload of a close frame.

        Returns:
            The close code as a big-endian 16-bit integer followed by the
            UTF-8 encoded reason.

        Raises:
            ProtocolError: If the close code is invalid.

        """
        self.check()
        return struct.pack("!H", self.code) + self.reason.encode()

    def check(self) -> None:
        """
        Check that the close code has a valid value for a close frame.

        Raises:
            ProtocolError: If the close code is invalid.

        """
        if not (self.code in EXTERNAL_CLOSE_CODES or 3000 <= self.code < 5000):
            raise ProtocolError("invalid status code")
427
428
429# At the bottom to break import cycles created by type annotations.
430from . import extensions # noqa: E402