Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/websocket/_abnf.py: 32%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1import array
2import os
3import struct
4import sys
5from threading import Lock
6from typing import Callable, Optional, Union
8from ._exceptions import WebSocketPayloadException, WebSocketProtocolException
9from ._utils import validate_utf8
11"""
12_abnf.py
13websocket - WebSocket client library for Python
15Copyright 2024 engn33r
17Licensed under the Apache License, Version 2.0 (the "License");
18you may not use this file except in compliance with the License.
19You may obtain a copy of the License at
21 http://www.apache.org/licenses/LICENSE-2.0
23Unless required by applicable law or agreed to in writing, software
24distributed under the License is distributed on an "AS IS" BASIS,
25WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
26See the License for the specific language governing permissions and
27limitations under the License.
28"""
try:
    # Prefer the compiled XOR masker from wsaccel when it is installed.
    # It is only ~10% faster than the pure-Python fallback below, and the
    # wsaccel project is unmaintained, so its absence is not a problem.
    from wsaccel.xormask import XorMaskerSimple

    def _mask(mask_value: array.array, data_value: array.array) -> bytes:
        """XOR ``data_value`` with the repeating 4-byte ``mask_value``."""
        return XorMaskerSimple(mask_value).process(data_value)

except ImportError:
    # Pure-Python fallback: build a key stream as long as the data, read both
    # as integers with the same byte order, XOR once, and convert back.
    # Using one byte order on both sides makes the result a byte-wise XOR.
    native_byteorder = sys.byteorder

    def _mask(mask_value: array.array, data_value: array.array) -> bytes:
        """XOR ``data_value`` with the repeating 4-byte ``mask_value``."""
        size = len(data_value)
        full_repeats, tail = divmod(size, 4)
        key_stream = mask_value * full_repeats + mask_value[:tail]
        data_int = int.from_bytes(data_value, native_byteorder)
        key_int = int.from_bytes(key_stream, native_byteorder)
        return (data_int ^ key_int).to_bytes(size, native_byteorder)
# Public names exported by a star-import of this module. Every STATUS_*
# close code defined below is listed, including 1012/1013, which are also
# accepted in VALID_CLOSE_STATUS.
__all__ = [
    "ABNF",
    "continuous_frame",
    "frame_buffer",
    "STATUS_NORMAL",
    "STATUS_GOING_AWAY",
    "STATUS_PROTOCOL_ERROR",
    "STATUS_UNSUPPORTED_DATA_TYPE",
    "STATUS_STATUS_NOT_AVAILABLE",
    "STATUS_ABNORMAL_CLOSED",
    "STATUS_INVALID_PAYLOAD",
    "STATUS_POLICY_VIOLATION",
    "STATUS_MESSAGE_TOO_BIG",
    "STATUS_INVALID_EXTENSION",
    "STATUS_UNEXPECTED_CONDITION",
    "STATUS_SERVICE_RESTART",
    "STATUS_TRY_AGAIN_LATER",
    "STATUS_BAD_GATEWAY",
    "STATUS_TLS_HANDSHAKE_ERROR",
]
# Close frame status codes (RFC 6455 section 7.4 and the IANA WebSocket
# Close Code Number registry).
STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA_TYPE = 1003
STATUS_STATUS_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSED = 1006
STATUS_INVALID_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_INVALID_EXTENSION = 1010
STATUS_UNEXPECTED_CONDITION = 1011
STATUS_SERVICE_RESTART = 1012
STATUS_TRY_AGAIN_LATER = 1013
STATUS_BAD_GATEWAY = 1014
STATUS_TLS_HANDSHAKE_ERROR = 1015

# Codes in the 1000-range accepted in a received close frame. 1005, 1006
# and 1015 are intentionally left out of this tuple.
VALID_CLOSE_STATUS = (
    # 1000-1003
    STATUS_NORMAL,
    STATUS_GOING_AWAY,
    STATUS_PROTOCOL_ERROR,
    STATUS_UNSUPPORTED_DATA_TYPE,
    # 1007-1014
    STATUS_INVALID_PAYLOAD,
    STATUS_POLICY_VIOLATION,
    STATUS_MESSAGE_TOO_BIG,
    STATUS_INVALID_EXTENSION,
    STATUS_UNEXPECTED_CONDITION,
    STATUS_SERVICE_RESTART,
    STATUS_TRY_AGAIN_LATER,
    STATUS_BAD_GATEWAY,
)
class ABNF:
    """
    ABNF frame class.
    See http://tools.ietf.org/html/rfc5234
    and http://tools.ietf.org/html/rfc6455#section-5.2

    Holds the header bits (fin, rsv1-3, opcode, mask flag) and the payload of
    a single WebSocket frame, and can serialize (``format``), mask/unmask
    (``mask``) and check (``validate``) it.
    """

    # operation code values.
    OPCODE_CONT = 0x0
    OPCODE_TEXT = 0x1
    OPCODE_BINARY = 0x2
    OPCODE_CLOSE = 0x8
    OPCODE_PING = 0x9
    OPCODE_PONG = 0xA

    # available operation code value tuple
    OPCODES = (
        OPCODE_CONT,
        OPCODE_TEXT,
        OPCODE_BINARY,
        OPCODE_CLOSE,
        OPCODE_PING,
        OPCODE_PONG,
    )

    # opcode human readable string
    OPCODE_MAP = {
        OPCODE_CONT: "cont",
        OPCODE_TEXT: "text",
        OPCODE_BINARY: "binary",
        OPCODE_CLOSE: "close",
        OPCODE_PING: "ping",
        OPCODE_PONG: "pong",
    }

    # data length threshold.
    # Payloads shorter than LENGTH_7 fit in the 7-bit length field, shorter
    # than LENGTH_16 use the 16-bit extended length, and anything shorter
    # than LENGTH_63 uses the 64-bit extended length (see ``format``).
    LENGTH_7 = 0x7E
    LENGTH_16 = 1 << 16
    LENGTH_63 = 1 << 63

    def __init__(
        self,
        fin: int = 0,
        rsv1: int = 0,
        rsv2: int = 0,
        rsv3: int = 0,
        opcode: int = OPCODE_TEXT,
        mask_value: int = 1,
        data: Union[str, bytes, None] = "",
    ) -> None:
        """
        Constructor for ABNF. Please check RFC for arguments.

        ``data=None`` is normalized to an empty string.
        """
        self.fin = fin
        self.rsv1 = rsv1
        self.rsv2 = rsv2
        self.rsv3 = rsv3
        self.opcode = opcode
        self.mask_value = mask_value
        if data is None:
            data = ""
        self.data = data
        # Source of the 4-byte masking key used by ``format``; replaceable.
        self.get_mask_key = os.urandom

    def validate(self, skip_utf8_validation: bool = False) -> None:
        """
        Validate the ABNF frame.

        Parameters
        ----------
        skip_utf8_validation: skip utf8 validation.

        Raises
        ------
        WebSocketProtocolException
            If a reserved bit is set, the opcode is unknown, a ping frame is
            fragmented, or a close frame has a malformed payload or an
            unacceptable status code.
        """
        if self.rsv1 or self.rsv2 or self.rsv3:
            raise WebSocketProtocolException("rsv is not implemented, yet")

        if self.opcode not in ABNF.OPCODES:
            # Interpolate the value: passing it as a second exception argument
            # left the "%r" placeholder unformatted in the message.
            raise WebSocketProtocolException(f"Invalid opcode {self.opcode!r}")

        if self.opcode == ABNF.OPCODE_PING and not self.fin:
            raise WebSocketProtocolException("Invalid ping frame.")

        if self.opcode == ABNF.OPCODE_CLOSE:
            length = len(self.data)
            if not length:
                # An empty close frame carries no status code to check.
                return
            # A non-empty close payload needs a full 2-byte status code and
            # must not exceed 125 bytes.
            if length == 1 or length >= 126:
                raise WebSocketProtocolException("Invalid close frame.")
            # Bytes after the status code are the close reason text.
            if (
                length > 2
                and not skip_utf8_validation
                and not validate_utf8(self.data[2:])
            ):
                raise WebSocketProtocolException("Invalid close frame.")

            # Status code is the first two payload bytes, big-endian.
            code = 256 * int(self.data[0]) + int(self.data[1])
            if not self._is_valid_close_status(code):
                raise WebSocketProtocolException(f"Invalid close opcode {code!r}")

    @staticmethod
    def _is_valid_close_status(code: int) -> bool:
        """Return True if ``code`` is a listed close code or in 3000-4999."""
        return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)

    def __str__(self) -> str:
        return f"fin={self.fin} opcode={self.opcode} data={self.data}"

    @staticmethod
    def create_frame(data: Union[bytes, str], opcode: int, fin: int = 1) -> "ABNF":
        """
        Create frame to send text, binary and other data.

        Parameters
        ----------
        data: str
            data to send. This is string value(byte array).
            If opcode is OPCODE_TEXT and this value is unicode,
            data value is converted into unicode string, automatically.
        opcode: int
            operation code. please see OPCODE_MAP.
        fin: int
            fin flag. if set to 0, create continue fragmentation.
        """
        if opcode == ABNF.OPCODE_TEXT and isinstance(data, str):
            data = data.encode("utf-8")
        # mask must be set if send data from client
        return ABNF(fin, 0, 0, 0, opcode, 1, data)

    def format(self) -> bytes:
        """
        Format this object to string(byte array) to send data to server.

        Returns the header byte, the length byte(s) (7-bit, or 16/64-bit
        extended), then either the raw payload (when ``mask_value`` is falsy)
        or the 4-byte mask key followed by the masked payload. For an
        unmasked ``str`` payload, ``self.data`` is replaced by its UTF-8
        encoding.

        Raises
        ------
        ValueError
            If a flag bit is not 0/1, the opcode is unknown, or the payload
            is too long to encode.
        """
        if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
            raise ValueError("not 0 or 1")
        if self.opcode not in ABNF.OPCODES:
            raise ValueError("Invalid OPCODE")
        if not self.mask_value and isinstance(self.data, str):
            # Encode before measuring: the advertised length must be the
            # number of bytes sent, not the character count of the str.
            self.data = self.data.encode("utf-8")
        length = len(self.data)
        if length >= ABNF.LENGTH_63:
            raise ValueError("data is too long")

        frame_header = chr(
            self.fin << 7
            | self.rsv1 << 6
            | self.rsv2 << 5
            | self.rsv3 << 4
            | self.opcode
        ).encode("latin-1")
        if length < ABNF.LENGTH_7:
            frame_header += chr(self.mask_value << 7 | length).encode("latin-1")
        elif length < ABNF.LENGTH_16:
            frame_header += chr(self.mask_value << 7 | 0x7E).encode("latin-1")
            frame_header += struct.pack("!H", length)
        else:
            frame_header += chr(self.mask_value << 7 | 0x7F).encode("latin-1")
            frame_header += struct.pack("!Q", length)

        if not self.mask_value:
            return frame_header + self.data
        mask_key = self.get_mask_key(4)
        return frame_header + self._get_masked(mask_key)

    def _get_masked(self, mask_key: Union[str, bytes]) -> bytes:
        """Return ``mask_key`` (as bytes) followed by the masked payload."""
        s = ABNF.mask(mask_key, self.data)

        if isinstance(mask_key, str):
            # Must match how ABNF.mask turns a str key into bytes (latin-1);
            # UTF-8 would emit different key bytes for chars >= 0x80 and the
            # peer would unmask the payload with the wrong key.
            mask_key = mask_key.encode("latin-1")

        return mask_key + s

    @staticmethod
    def mask(mask_key: Union[str, bytes], data: Union[str, bytes]) -> bytes:
        """
        Mask or unmask data. Just do xor for each byte

        Parameters
        ----------
        mask_key: bytes or str
            4 byte mask.
        data: bytes or str
            data to mask/unmask.
        """
        if data is None:
            data = ""

        if isinstance(mask_key, str):
            mask_key = mask_key.encode("latin-1")

        if isinstance(data, str):
            data = data.encode("latin-1")

        return _mask(array.array("B", mask_key), array.array("B", data))
class frame_buffer:
    """
    Read WebSocket frames from a receive callable, one frame per call.

    Progress on the current frame (header tuple, payload length, mask key)
    is stored on the instance, and bytes already returned by ``recv_fn`` are
    kept in ``recv_buffer``. If ``recv_fn`` raises part-way through a frame,
    a later ``recv_frame`` call continues from that saved state.
    """

    # Positions inside the ``header`` tuple built by ``recv_header``.
    _HEADER_MASK_INDEX = 5
    _HEADER_LENGTH_INDEX = 6

    def __init__(
        self, recv_fn: Callable[[int], bytes], skip_utf8_validation: bool
    ) -> None:
        """
        Parameters
        ----------
        recv_fn: Callable[[int], bytes]
            Called with a maximum byte count. It must return bytes, which
            ``recv_strict`` concatenates with ``b"".join``.
        skip_utf8_validation: bool
            Forwarded to ``ABNF.validate`` for every decoded frame.
        """
        self.recv = recv_fn
        self.skip_utf8_validation = skip_utf8_validation
        # Buffers over the packets from the layer beneath until desired amount
        # bytes of bytes are received.
        self.recv_buffer: list = []
        self.clear()
        self.lock = Lock()

    def clear(self) -> None:
        """Forget per-frame state so the next frame is parsed from scratch."""
        self.header: Optional[tuple] = None
        self.length: Optional[int] = None
        self.mask_value: Union[bytes, str, None] = None

    def has_received_header(self) -> bool:
        # NOTE: despite the name, True means the header has NOT been read yet.
        return self.header is None

    def recv_header(self) -> None:
        """Read the 2 fixed header bytes and unpack them into ``self.header``.

        ``self.header`` becomes (fin, rsv1, rsv2, rsv3, opcode, has_mask,
        length_bits).
        """
        header = self.recv_strict(2)
        b1 = header[0]
        fin = b1 >> 7 & 1
        rsv1 = b1 >> 6 & 1
        rsv2 = b1 >> 5 & 1
        rsv3 = b1 >> 4 & 1
        opcode = b1 & 0xF
        b2 = header[1]
        has_mask = b2 >> 7 & 1
        length_bits = b2 & 0x7F

        self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)

    def has_mask(self) -> Union[bool, int]:
        """Return the header's mask bit, or False before a header is read."""
        if not self.header:
            return False
        header_val: int = self.header[frame_buffer._HEADER_MASK_INDEX]
        return header_val

    def has_received_length(self) -> bool:
        # NOTE: despite the name, True means the length has NOT been read yet.
        return self.length is None

    def recv_length(self) -> None:
        """Resolve the payload length, reading an extended length if needed.

        7-bit values 0x7E and 0x7F select a following 16-bit or 64-bit
        big-endian length.
        """
        bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
        length_bits = bits & 0x7F
        if length_bits == 0x7E:
            v = self.recv_strict(2)
            self.length = struct.unpack("!H", v)[0]
        elif length_bits == 0x7F:
            v = self.recv_strict(8)
            self.length = struct.unpack("!Q", v)[0]
        else:
            self.length = length_bits

    def has_received_mask(self) -> bool:
        # NOTE: despite the name, True means the mask has NOT been read yet.
        return self.mask_value is None

    def recv_mask(self) -> None:
        """Read the 4-byte mask key, or store "" when the frame is unmasked."""
        self.mask_value = self.recv_strict(4) if self.has_mask() else ""

    def recv_frame(self) -> "ABNF":
        """
        Read one complete frame, unmask it if needed, and validate it.

        Steps already completed on an earlier, interrupted call are skipped.
        The lock serializes concurrent callers on this buffer.
        """
        with self.lock:
            # Header
            if self.has_received_header():
                self.recv_header()
            (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header

            # Frame length
            if self.has_received_length():
                self.recv_length()
            length = self.length

            # Mask
            if self.has_received_mask():
                self.recv_mask()
            mask_value = self.mask_value

            # Payload
            payload = self.recv_strict(length)
            if has_mask:
                payload = ABNF.mask(mask_value, payload)

            # Reset for next frame
            self.clear()

            frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
            frame.validate(self.skip_utf8_validation)

        return frame

    def recv_strict(self, bufsize: int) -> bytes:
        """
        Return exactly ``bufsize`` bytes, buffering any surplus for next time.

        NOTE(review): if ``recv_fn`` returns b"" this loop makes no progress
        and never ends; callers presumably raise on a closed connection
        inside ``recv_fn`` instead. Confirm.
        """
        shortage = bufsize - sum(map(len, self.recv_buffer))
        while shortage > 0:
            # Limit buffer size that we pass to socket.recv() to avoid
            # fragmenting the heap -- the number of bytes recv() actually
            # reads is limited by socket buffer and is relatively small,
            # yet passing large numbers repeatedly causes lots of large
            # buffers allocated and then shrunk, which results in
            # fragmentation.
            bytes_ = self.recv(min(16384, shortage))
            self.recv_buffer.append(bytes_)
            shortage -= len(bytes_)

        unified = b"".join(self.recv_buffer)

        if shortage == 0:
            self.recv_buffer = []
            return unified
        else:
            # Over-read: keep the surplus for the next call.
            self.recv_buffer = [unified[bufsize:]]
            return unified[:bufsize]
class continuous_frame:
    """
    Collect the payload of a fragmented data message.

    The first TEXT/BINARY fragment records the message opcode. Later
    fragments are appended to the collected payload, and ``extract`` hands
    the combined payload back on a frame object.
    """

    def __init__(self, fire_cont_frame: bool, skip_utf8_validation: bool) -> None:
        # When true, ``is_fire`` reports every fragment, not only the final
        # one, and ``extract`` skips UTF-8 checking.
        self.fire_cont_frame = fire_cont_frame
        self.skip_utf8_validation = skip_utf8_validation
        # [opcode of the first fragment, payload collected so far] or None.
        self.cont_data: Optional[list] = None
        # Opcode of the message still in progress, or None between messages.
        self.recving_frames: Optional[int] = None

    def validate(self, frame: ABNF) -> None:
        """Reject a CONT frame with no message in progress, or a new
        TEXT/BINARY frame while one is still in progress."""
        opcode = frame.opcode
        if self.recving_frames:
            illegal = opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY)
        else:
            illegal = opcode == ABNF.OPCODE_CONT
        if illegal:
            raise WebSocketProtocolException("Illegal frame")

    def add(self, frame: ABNF) -> None:
        """Append ``frame``'s payload, starting a new message if none is held."""
        if not self.cont_data:
            if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
                self.recving_frames = frame.opcode
            self.cont_data = [frame.opcode, frame.data]
        else:
            self.cont_data[1] += frame.data

        # The final fragment ends the message in progress.
        if frame.fin:
            self.recving_frames = None

    def is_fire(self, frame: ABNF) -> Union[bool, int]:
        """Return a truthy value when collected data should be extracted."""
        if frame.fin:
            return frame.fin
        return self.fire_cont_frame

    def extract(self, frame: ABNF) -> tuple:
        """
        Move the collected payload onto ``frame`` and reset the collector.

        Returns (opcode of the first fragment, frame). When fragments are not
        fired individually and the message is TEXT, the payload must be valid
        UTF-8 unless validation is skipped; otherwise
        WebSocketPayloadException is raised.
        """
        opcode, payload = self.cont_data
        self.cont_data = None
        frame.data = payload
        needs_utf8_check = (
            not self.fire_cont_frame
            and opcode == ABNF.OPCODE_TEXT
            and not self.skip_utf8_validation
        )
        if needs_utf8_check and not validate_utf8(frame.data):
            raise WebSocketPayloadException(f"cannot decode: {repr(frame.data)}")
        return opcode, frame