Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/websocket/_abnf.py: 32%
226 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:34 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:34 +0000
1import array
2import os
3import struct
4import sys
6from threading import Lock
7from typing import Callable, Union
9from ._exceptions import *
10from ._utils import validate_utf8
12"""
13_abnf.py
14websocket - WebSocket client library for Python
16Copyright 2023 engn33r
18Licensed under the Apache License, Version 2.0 (the "License");
19you may not use this file except in compliance with the License.
20You may obtain a copy of the License at
22 http://www.apache.org/licenses/LICENSE-2.0
24Unless required by applicable law or agreed to in writing, software
25distributed under the License is distributed on an "AS IS" BASIS,
26WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
27See the License for the specific language governing permissions and
28limitations under the License.
29"""
31try:
32 # If wsaccel is available, use compiled routines to mask data.
33 # wsaccel only provides around a 10% speed boost compared
34 # to the websocket-client _mask() implementation.
35 # Note that wsaccel is unmaintained.
36 from wsaccel.xormask import XorMaskerSimple
38 def _mask(_m, _d) -> bytes:
39 return XorMaskerSimple(_m).process(_d)
41except ImportError:
42 # wsaccel is not available, use websocket-client _mask()
43 native_byteorder = sys.byteorder
45 def _mask(mask_value: array.array, data_value: array.array) -> bytes:
46 datalen = len(data_value)
47 int_data_value = int.from_bytes(data_value, native_byteorder)
48 int_mask_value = int.from_bytes(mask_value * (datalen // 4) + mask_value[: datalen % 4], native_byteorder)
49 return (int_data_value ^ int_mask_value).to_bytes(datalen, native_byteorder)
# Public names exported by ``from <this module> import *``. Every STATUS_*
# close code defined below is listed, including 1012/1013, which were
# previously defined (and accepted as valid) but not exported.
__all__ = [
    'ABNF', 'continuous_frame', 'frame_buffer',
    'STATUS_NORMAL',
    'STATUS_GOING_AWAY',
    'STATUS_PROTOCOL_ERROR',
    'STATUS_UNSUPPORTED_DATA_TYPE',
    'STATUS_STATUS_NOT_AVAILABLE',
    'STATUS_ABNORMAL_CLOSED',
    'STATUS_INVALID_PAYLOAD',
    'STATUS_POLICY_VIOLATION',
    'STATUS_MESSAGE_TOO_BIG',
    'STATUS_INVALID_EXTENSION',
    'STATUS_UNEXPECTED_CONDITION',
    'STATUS_SERVICE_RESTART',
    'STATUS_TRY_AGAIN_LATER',
    'STATUS_BAD_GATEWAY',
    'STATUS_TLS_HANDSHAKE_ERROR',
]
# Close-frame status codes (RFC 6455, section 7.4.1).
STATUS_NORMAL = 1000                  # normal closure
STATUS_GOING_AWAY = 1001              # endpoint going away
STATUS_PROTOCOL_ERROR = 1002          # protocol error
STATUS_UNSUPPORTED_DATA_TYPE = 1003   # unacceptable data type
STATUS_STATUS_NOT_AVAILABLE = 1005    # no status code was present
STATUS_ABNORMAL_CLOSED = 1006         # connection closed abnormally
STATUS_INVALID_PAYLOAD = 1007         # payload inconsistent with message type
STATUS_POLICY_VIOLATION = 1008        # policy violation
STATUS_MESSAGE_TOO_BIG = 1009         # message too big to process
STATUS_INVALID_EXTENSION = 1010       # required extension not negotiated
STATUS_UNEXPECTED_CONDITION = 1011    # unexpected server condition
STATUS_SERVICE_RESTART = 1012
STATUS_TRY_AGAIN_LATER = 1013
STATUS_BAD_GATEWAY = 1014
STATUS_TLS_HANDSHAKE_ERROR = 1015     # TLS handshake failure

# Status codes in the registered range that ABNF.validate() accepts in a
# received close frame. 1005, 1006 and 1015 are intentionally left out:
# RFC 6455 reserves them for local reporting, not for use on the wire.
VALID_CLOSE_STATUS = (
    STATUS_NORMAL, STATUS_GOING_AWAY, STATUS_PROTOCOL_ERROR,
    STATUS_UNSUPPORTED_DATA_TYPE, STATUS_INVALID_PAYLOAD,
    STATUS_POLICY_VIOLATION, STATUS_MESSAGE_TOO_BIG,
    STATUS_INVALID_EXTENSION, STATUS_UNEXPECTED_CONDITION,
    STATUS_SERVICE_RESTART, STATUS_TRY_AGAIN_LATER, STATUS_BAD_GATEWAY,
)
class ABNF:
    """
    ABNF frame class.
    See http://tools.ietf.org/html/rfc5234
    and http://tools.ietf.org/html/rfc6455#section-5.2

    An instance holds one WebSocket frame: the FIN and RSV1-3 bits, the
    opcode, the MASK flag and the payload. ``format()`` serialises it for
    sending, ``validate()`` checks a received frame, and ``mask()`` applies
    the XOR masking used in both directions.
    """

    # operation code values.
    OPCODE_CONT = 0x0
    OPCODE_TEXT = 0x1
    OPCODE_BINARY = 0x2
    OPCODE_CLOSE = 0x8
    OPCODE_PING = 0x9
    OPCODE_PONG = 0xa

    # available operation code value tuple
    OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE,
               OPCODE_PING, OPCODE_PONG)

    # opcode human readable string
    OPCODE_MAP = {
        OPCODE_CONT: "cont",
        OPCODE_TEXT: "text",
        OPCODE_BINARY: "binary",
        OPCODE_CLOSE: "close",
        OPCODE_PING: "ping",
        OPCODE_PONG: "pong"
    }

    # Data length thresholds used by format() to pick the length encoding.
    # Below LENGTH_7 the length fits the 7-bit field; below LENGTH_16 a
    # 16-bit extended length is used; otherwise a 64-bit one is used.
    # Payloads of LENGTH_63 bytes or more are rejected.
    LENGTH_7 = 0x7e
    LENGTH_16 = 1 << 16
    LENGTH_63 = 1 << 63

    def __init__(self, fin: int = 0, rsv1: int = 0, rsv2: int = 0, rsv3: int = 0,
                 opcode: int = OPCODE_TEXT, mask: int = 1, data: Union[str, bytes] = "") -> None:
        """
        Constructor for ABNF. Please check RFC for arguments.

        ``data=None`` is normalised to an empty string. The mask key source
        defaults to ``os.urandom`` and is stored in ``get_mask_key``.
        """
        self.fin = fin
        self.rsv1 = rsv1
        self.rsv2 = rsv2
        self.rsv3 = rsv3
        self.opcode = opcode
        self.mask = mask
        if data is None:
            data = ""
        self.data = data
        self.get_mask_key = os.urandom

    def validate(self, skip_utf8_validation: bool = False) -> None:
        """
        Validate the ABNF frame.

        Parameters
        ----------
        skip_utf8_validation: bool
            Skip UTF-8 validation of a close frame's reason text.

        Raises
        ------
        WebSocketProtocolException
            If a reserved bit is set, the opcode is unknown, a ping frame is
            fragmented, or a close frame has a bad length, an invalid UTF-8
            reason or an unacceptable status code.
        """
        if self.rsv1 or self.rsv2 or self.rsv3:
            raise WebSocketProtocolException("rsv is not implemented, yet")

        if self.opcode not in ABNF.OPCODES:
            # Interpolate here: passing the value as a second exception
            # argument left "%r" unformatted in the message.
            raise WebSocketProtocolException("Invalid opcode %r" % (self.opcode,))

        if self.opcode == ABNF.OPCODE_PING and not self.fin:
            raise WebSocketProtocolException("Invalid ping frame.")

        if self.opcode == ABNF.OPCODE_CLOSE:
            length = len(self.data)
            if not length:
                # An empty close frame carries no status code.
                return
            # A non-empty close payload must contain at least the 2-byte
            # status code, and control payloads are limited to 125 bytes.
            if length == 1 or length >= 126:
                raise WebSocketProtocolException("Invalid close frame.")
            # Bytes after the status code are the UTF-8 close reason.
            if length > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]):
                raise WebSocketProtocolException("Invalid close frame.")

            # Big-endian 16-bit status code; assumes self.data is bytes
            # (indexing yields ints), as produced by frame_buffer.
            code = 256 * self.data[0] + self.data[1]
            if not self._is_valid_close_status(code):
                raise WebSocketProtocolException("Invalid close opcode %r" % (code,))

    @staticmethod
    def _is_valid_close_status(code: int) -> bool:
        """Return True for a registered valid code or one in [3000, 5000)."""
        return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)

    def __str__(self) -> str:
        return "fin=" + str(self.fin) \
            + " opcode=" + str(self.opcode) \
            + " data=" + str(self.data)

    @staticmethod
    def create_frame(data: Union[bytes, str], opcode: int, fin: int = 1) -> 'ABNF':
        """
        Create frame to send text, binary and other data.

        Parameters
        ----------
        data: str
            data to send. This is string value(byte array).
            If opcode is OPCODE_TEXT and this value is a str,
            it is encoded to UTF-8 bytes automatically.
        opcode: int
            operation code. please see OPCODE_MAP.
        fin: int
            fin flag. if set to 0, create continue fragmentation.

        Returns
        -------
        ABNF
            A frame with all RSV bits cleared and masking enabled.
        """
        if opcode == ABNF.OPCODE_TEXT and isinstance(data, str):
            data = data.encode("utf-8")
        # mask must be set if send data from client
        return ABNF(fin, 0, 0, 0, opcode, 1, data)

    def format(self) -> bytes:
        """
        Serialise this frame to bytes, ready to send to the server.

        When ``self.mask`` is set, a fresh 4-byte key is obtained from
        ``self.get_mask_key`` and the payload is masked with it.

        Raises
        ------
        ValueError
            If FIN/RSV bits are not 0 or 1, the opcode is unknown, or the
            payload is 2**63 bytes or longer.
        """
        if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
            raise ValueError("not 0 or 1")
        if self.opcode not in ABNF.OPCODES:
            raise ValueError("Invalid OPCODE")
        length = len(self.data)
        if length >= ABNF.LENGTH_63:
            raise ValueError("data is too long")

        frame_header = chr(self.fin << 7 |
                           self.rsv1 << 6 | self.rsv2 << 5 | self.rsv3 << 4 |
                           self.opcode).encode('latin-1')
        if length < ABNF.LENGTH_7:
            frame_header += chr(self.mask << 7 | length).encode('latin-1')
        elif length < ABNF.LENGTH_16:
            # 126 marker: a 16-bit big-endian length follows.
            frame_header += chr(self.mask << 7 | 0x7e).encode('latin-1')
            frame_header += struct.pack("!H", length)
        else:
            # 127 marker: a 64-bit big-endian length follows.
            frame_header += chr(self.mask << 7 | 0x7f).encode('latin-1')
            frame_header += struct.pack("!Q", length)

        if not self.mask:
            return frame_header + self.data
        else:
            mask_key = self.get_mask_key(4)
            return frame_header + self._get_masked(mask_key)

    def _get_masked(self, mask_key: Union[str, bytes]) -> bytes:
        """Return the mask key followed by the payload masked with it."""
        s = ABNF.mask(mask_key, self.data)

        if isinstance(mask_key, str):
            # Encode the key exactly as mask() does (latin-1), so the key
            # written to the wire is the one that masked the payload. UTF-8
            # would emit two bytes for chars 0x80-0xFF and corrupt the frame.
            mask_key = mask_key.encode('latin-1')

        return mask_key + s

    @staticmethod
    def mask(mask_key: Union[str, bytes], data: Union[str, bytes]) -> bytes:
        """
        Mask or unmask data. Just do xor for each byte

        Parameters
        ----------
        mask_key: bytes or str
            4 byte mask. A str is encoded as latin-1.
        data: bytes or str
            data to mask/unmask. A str is encoded as latin-1; None is
            treated as empty.
        """
        if data is None:
            data = ""

        if isinstance(mask_key, str):
            mask_key = mask_key.encode('latin-1')

        if isinstance(data, str):
            data = data.encode('latin-1')

        return _mask(array.array("B", mask_key), array.array("B", data))
class frame_buffer:
    """
    Incrementally receive and decode WebSocket frames.

    Bytes are pulled from ``recv_fn``, and the frame header, payload length
    and mask key are parsed in separate stages. Each stage's result is stored
    on the instance and cleared only after a whole frame has been decoded.
    If ``recv_fn`` raises part-way through a frame, a later ``recv_frame()``
    call skips the stages already completed, and bytes already received stay
    in ``recv_buffer``.
    """
    # Positions of the MASK bit and the 7-bit length field inside the
    # ``self.header`` tuple built by recv_header().
    _HEADER_MASK_INDEX = 5
    _HEADER_LENGTH_INDEX = 6

    def __init__(self, recv_fn: Callable[[int], bytes], skip_utf8_validation: bool) -> None:
        """
        Parameters
        ----------
        recv_fn: Callable[[int], bytes]
            Called with a maximum byte count; returns the bytes read.
            NOTE(review): recv_strict() keeps calling it until enough bytes
            arrive, so it is presumably expected to raise (not return b"")
            once the connection is closed -- confirm against callers.
        skip_utf8_validation: bool
            Forwarded to ABNF.validate() for every decoded frame.
        """
        self.recv = recv_fn
        self.skip_utf8_validation = skip_utf8_validation
        # Buffers the packets from the layer beneath until the desired number
        # of bytes has been received.
        self.recv_buffer = []
        self.clear()
        # Serialises recv_frame() calls on this buffer.
        self.lock = Lock()

    def clear(self) -> None:
        """Forget the parsed header, length and mask (not the byte buffer)."""
        self.header = None
        self.length = None
        self.mask = None

    def has_received_header(self) -> bool:
        # NOTE: despite the name, this returns True when the header has NOT
        # been received yet; recv_frame() uses it to decide whether to read it.
        return self.header is None

    def recv_header(self) -> None:
        """Read the 2-byte frame header and store its fields in self.header."""
        header = self.recv_strict(2)
        b1 = header[0]
        fin = b1 >> 7 & 1
        rsv1 = b1 >> 6 & 1
        rsv2 = b1 >> 5 & 1
        rsv3 = b1 >> 4 & 1
        opcode = b1 & 0xf
        b2 = header[1]
        has_mask = b2 >> 7 & 1
        length_bits = b2 & 0x7f

        self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)

    def has_mask(self) -> Union[bool, int]:
        """Return the MASK bit (0/1), or False if no header is parsed yet."""
        if not self.header:
            return False
        return self.header[frame_buffer._HEADER_MASK_INDEX]

    def has_received_length(self) -> bool:
        # NOTE: True means the payload length has NOT been read yet.
        return self.length is None

    def recv_length(self) -> None:
        """Resolve the payload length, reading an extended length if needed."""
        bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
        length_bits = bits & 0x7f
        if length_bits == 0x7e:
            # 126: the real length is the next 16-bit big-endian integer.
            v = self.recv_strict(2)
            self.length = struct.unpack("!H", v)[0]
        elif length_bits == 0x7f:
            # 127: the real length is the next 64-bit big-endian integer.
            v = self.recv_strict(8)
            self.length = struct.unpack("!Q", v)[0]
        else:
            self.length = length_bits

    def has_received_mask(self) -> bool:
        # NOTE: True means the mask key has NOT been read yet.
        return self.mask is None

    def recv_mask(self) -> None:
        """Read the 4-byte mask key, or store "" when the frame is unmasked."""
        self.mask = self.recv_strict(4) if self.has_mask() else ""

    def recv_frame(self) -> ABNF:
        """
        Receive, unmask and validate one complete frame.

        Stages finished by an earlier, interrupted call are not repeated.

        Returns
        -------
        ABNF
            The decoded frame.

        Raises
        ------
        Whatever ``recv_fn`` or ``ABNF.validate()`` raises.
        """

        with self.lock:
            # Header
            if self.has_received_header():
                self.recv_header()
            (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header

            # Frame length
            if self.has_received_length():
                self.recv_length()
            length = self.length

            # Mask
            if self.has_received_mask():
                self.recv_mask()
            mask = self.mask

            # Payload
            payload = self.recv_strict(length)
            if has_mask:
                payload = ABNF.mask(mask, payload)

            # Reset for next frame
            self.clear()

            frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
            frame.validate(self.skip_utf8_validation)

        return frame

    def recv_strict(self, bufsize: int) -> bytes:
        """
        Return exactly ``bufsize`` bytes, calling ``self.recv`` as needed.

        Bytes received beyond ``bufsize`` stay in ``recv_buffer`` for the
        next call.
        """
        shortage = bufsize - sum(map(len, self.recv_buffer))
        while shortage > 0:
            # Limit buffer size that we pass to socket.recv() to avoid
            # fragmenting the heap -- the number of bytes recv() actually
            # reads is limited by socket buffer and is relatively small,
            # yet passing large numbers repeatedly causes lots of large
            # buffers allocated and then shrunk, which results in
            # fragmentation.
            bytes_ = self.recv(min(16384, shortage))
            self.recv_buffer.append(bytes_)
            shortage -= len(bytes_)

        unified = b"".join(self.recv_buffer)

        if shortage == 0:
            self.recv_buffer = []
            return unified
        else:
            # More than requested is buffered: keep the surplus.
            self.recv_buffer = [unified[bufsize:]]
            return unified[:bufsize]
class continuous_frame:
    """
    Reassemble fragmented text/binary messages from successive frames.

    Parameters
    ----------
    fire_cont_frame: bool
        When true, every frame is handed out by ``extract()`` as it arrives
        instead of waiting for the final (FIN) fragment.
    skip_utf8_validation: bool
        When true, reassembled text messages are not checked for valid UTF-8.
    """

    def __init__(self, fire_cont_frame: bool, skip_utf8_validation: bool) -> None:
        self.fire_cont_frame = fire_cont_frame
        self.skip_utf8_validation = skip_utf8_validation
        # [opcode, accumulated payload] of the pending message, or None.
        self.cont_data = None
        # Opcode of the fragmented data message in progress, or None.
        self.recving_frames = None

    def validate(self, frame: ABNF) -> None:
        """
        Reject frames that break the fragmentation sequence.

        A continuation frame is illegal when no fragmented message is in
        progress, and a new text/binary frame is illegal while one is.
        """
        if self.recving_frames:
            illegal = frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY)
        else:
            illegal = frame.opcode == ABNF.OPCODE_CONT
        if illegal:
            raise WebSocketProtocolException("Illegal frame")

    def add(self, frame: ABNF) -> None:
        """Append ``frame``'s payload to the pending message (or start one)."""
        if not self.cont_data:
            if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
                self.recving_frames = frame.opcode
            self.cont_data = [frame.opcode, frame.data]
        else:
            self.cont_data[1] += frame.data

        if frame.fin:
            # Final fragment: the message is complete.
            self.recving_frames = None

    def is_fire(self, frame: ABNF) -> Union[bool, int]:
        """Return a truthy value when the pending data should be extracted."""
        if frame.fin:
            return frame.fin
        return self.fire_cont_frame

    def extract(self, frame: ABNF) -> list:
        """
        Move the pending payload into ``frame`` and reset the buffer.

        Returns
        -------
        list
            ``[opcode, frame]``, where opcode is that of the message's first
            buffered frame.

        Raises
        ------
        WebSocketPayloadException
            If a fully reassembled text message is not valid UTF-8 (only
            checked when neither fire_cont_frame nor skip_utf8_validation is
            set).
        """
        pending = self.cont_data
        self.cont_data = None
        opcode, payload = pending
        frame.data = payload
        must_check_utf8 = (not self.fire_cont_frame
                           and opcode == ABNF.OPCODE_TEXT
                           and not self.skip_utf8_validation)
        if must_check_utf8 and not validate_utf8(frame.data):
            raise WebSocketPayloadException(
                "cannot decode: " + repr(frame.data))

        return [opcode, frame]