Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/websocket/_abnf.py: 32%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

231 statements  

1import array 

2import os 

3import struct 

4import sys 

5from threading import Lock 

6from typing import Callable, Optional, Union 

7 

8from ._exceptions import WebSocketPayloadException, WebSocketProtocolException 

9from ._utils import validate_utf8 

10 

11""" 

12_abnf.py 

13websocket - WebSocket client library for Python 

14 

15Copyright 2024 engn33r 

16 

17Licensed under the Apache License, Version 2.0 (the "License"); 

18you may not use this file except in compliance with the License. 

19You may obtain a copy of the License at 

20 

21 http://www.apache.org/licenses/LICENSE-2.0 

22 

23Unless required by applicable law or agreed to in writing, software 

24distributed under the License is distributed on an "AS IS" BASIS, 

25WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

26See the License for the specific language governing permissions and 

27limitations under the License. 

28""" 

29 

try:
    # Prefer the compiled XOR masker from wsaccel when it is installed.
    # It is only around 10% faster than the pure-Python fallback defined
    # in the ``except`` branch, and wsaccel itself is unmaintained.
    from wsaccel.xormask import XorMaskerSimple

    def _mask(mask_value: array.array, data_value: array.array) -> bytes:
        """XOR ``data_value`` with the repeating key ``mask_value`` via wsaccel."""
        masker = XorMaskerSimple(mask_value)
        mask_result: bytes = masker.process(data_value)
        return mask_result

except ImportError:
    # wsaccel is not available, use the websocket-client implementation.
    native_byteorder = sys.byteorder

    def _mask(mask_value: array.array, data_value: array.array) -> bytes:
        """
        XOR ``data_value`` with the repeating key ``mask_value``.

        The key is tiled to the payload length, both byte sequences are
        turned into one big integer each, XOR-ed in a single operation and
        converted back.  The same byte order is used in both directions, so
        the result is a plain byte-wise XOR.

        Parameters
        ----------
        mask_value: array.array
            masking key (callers in this module pass 4 bytes).
        data_value: array.array
            payload to mask or unmask.
        """
        datalen = len(data_value)
        whole_keys, leftover = divmod(datalen, 4)
        tiled_key = mask_value * whole_keys + mask_value[:leftover]
        data_as_int = int.from_bytes(data_value, native_byteorder)
        key_as_int = int.from_bytes(tiled_key, native_byteorder)
        return (data_as_int ^ key_as_int).to_bytes(datalen, native_byteorder)

52 

53 

# Public names exported by a star-import of this module.  Every STATUS_*
# constant defined in this module is listed, including
# STATUS_SERVICE_RESTART and STATUS_TRY_AGAIN_LATER.
__all__ = [
    "ABNF",
    "continuous_frame",
    "frame_buffer",
    "STATUS_NORMAL",
    "STATUS_GOING_AWAY",
    "STATUS_PROTOCOL_ERROR",
    "STATUS_UNSUPPORTED_DATA_TYPE",
    "STATUS_STATUS_NOT_AVAILABLE",
    "STATUS_ABNORMAL_CLOSED",
    "STATUS_INVALID_PAYLOAD",
    "STATUS_POLICY_VIOLATION",
    "STATUS_MESSAGE_TOO_BIG",
    "STATUS_INVALID_EXTENSION",
    "STATUS_UNEXPECTED_CONDITION",
    "STATUS_SERVICE_RESTART",
    "STATUS_TRY_AGAIN_LATER",
    "STATUS_BAD_GATEWAY",
    "STATUS_TLS_HANDSHAKE_ERROR",
]

72 

# Status codes carried in the first two payload bytes of a close frame.
STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA_TYPE = 1003
STATUS_STATUS_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSED = 1006
STATUS_INVALID_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_INVALID_EXTENSION = 1010
STATUS_UNEXPECTED_CONDITION = 1011
STATUS_SERVICE_RESTART = 1012
STATUS_TRY_AGAIN_LATER = 1013
STATUS_BAD_GATEWAY = 1014
STATUS_TLS_HANDSHAKE_ERROR = 1015

# Codes that ABNF.validate() accepts inside a received close frame.
# 1005, 1006 and 1015 are defined above but intentionally not listed here.
VALID_CLOSE_STATUS = (
    STATUS_NORMAL,
    STATUS_GOING_AWAY,
    STATUS_PROTOCOL_ERROR,
    STATUS_UNSUPPORTED_DATA_TYPE,
    STATUS_INVALID_PAYLOAD,
    STATUS_POLICY_VIOLATION,
    STATUS_MESSAGE_TOO_BIG,
    STATUS_INVALID_EXTENSION,
    STATUS_UNEXPECTED_CONDITION,
    STATUS_SERVICE_RESTART,
    STATUS_TRY_AGAIN_LATER,
    STATUS_BAD_GATEWAY,
)

104 

105 

class ABNF:
    """
    ABNF frame class.
    See http://tools.ietf.org/html/rfc5234
    and http://tools.ietf.org/html/rfc6455#section-5.2
    """

    # operation code values.
    OPCODE_CONT = 0x0
    OPCODE_TEXT = 0x1
    OPCODE_BINARY = 0x2
    OPCODE_CLOSE = 0x8
    OPCODE_PING = 0x9
    OPCODE_PONG = 0xA

    # available operation code value tuple
    OPCODES = (
        OPCODE_CONT,
        OPCODE_TEXT,
        OPCODE_BINARY,
        OPCODE_CLOSE,
        OPCODE_PING,
        OPCODE_PONG,
    )

    # opcode human readable string
    OPCODE_MAP = {
        OPCODE_CONT: "cont",
        OPCODE_TEXT: "text",
        OPCODE_BINARY: "binary",
        OPCODE_CLOSE: "close",
        OPCODE_PING: "ping",
        OPCODE_PONG: "pong",
    }

    # data length threshold.
    LENGTH_7 = 0x7E
    LENGTH_16 = 1 << 16
    LENGTH_63 = 1 << 63

    def __init__(
        self,
        fin: int = 0,
        rsv1: int = 0,
        rsv2: int = 0,
        rsv3: int = 0,
        opcode: int = OPCODE_TEXT,
        mask_value: int = 1,
        data: Union[str, bytes, None] = "",
    ) -> None:
        """
        Constructor for ABNF. Please check RFC for arguments.

        Parameters
        ----------
        fin, rsv1, rsv2, rsv3: int
            header flag bits; format() requires each to be 0 or 1.
        opcode: int
            operation code, one of ABNF.OPCODES.
        mask_value: int
            mask flag; a truthy value makes format() mask the payload.
        data: str, bytes or None
            payload. None is stored as an empty string.
        """
        self.fin = fin
        self.rsv1 = rsv1
        self.rsv2 = rsv2
        self.rsv3 = rsv3
        self.opcode = opcode
        self.mask_value = mask_value
        if data is None:
            data = ""
        self.data = data
        # Called as get_mask_key(4) by format(); replaceable per instance.
        self.get_mask_key = os.urandom

    def validate(self, skip_utf8_validation: bool = False) -> None:
        """
        Validate the ABNF frame.

        Parameters
        ----------
        skip_utf8_validation: skip utf8 validation.

        Raises
        ------
        WebSocketProtocolException
            if a reserved bit is set, the opcode is unknown, a ping frame is
            fragmented, or a close frame has a malformed payload or an
            unacceptable status code.
        """
        if self.rsv1 or self.rsv2 or self.rsv3:
            raise WebSocketProtocolException("rsv is not implemented, yet")

        if self.opcode not in ABNF.OPCODES:
            # Interpolate the value here: exceptions do not apply %-style
            # arguments the way logging calls do.
            raise WebSocketProtocolException(f"Invalid opcode {self.opcode!r}")

        if self.opcode == ABNF.OPCODE_PING and not self.fin:
            raise WebSocketProtocolException("Invalid ping frame.")

        if self.opcode == ABNF.OPCODE_CLOSE:
            payload_len = len(self.data)
            if not payload_len:
                # A close frame without a body carries no status to check.
                return
            if payload_len == 1 or payload_len >= 126:
                raise WebSocketProtocolException("Invalid close frame.")
            # Everything after the 2-byte status code is the close reason.
            if (
                payload_len > 2
                and not skip_utf8_validation
                and not validate_utf8(self.data[2:])
            ):
                raise WebSocketProtocolException("Invalid close frame.")

            # Status code is the first two payload bytes, big-endian.
            code = 256 * int(self.data[0]) + int(self.data[1])
            if not self._is_valid_close_status(code):
                raise WebSocketProtocolException(f"Invalid close opcode {code!r}")

    @staticmethod
    def _is_valid_close_status(code: int) -> bool:
        """Return True for codes in VALID_CLOSE_STATUS or in 3000-4999."""
        return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)

    def __str__(self) -> str:
        return f"fin={self.fin} opcode={self.opcode} data={self.data}"

    @staticmethod
    def create_frame(data: Union[bytes, str], opcode: int, fin: int = 1) -> "ABNF":
        """
        Create frame to send text, binary and other data.

        Parameters
        ----------
        data: str or bytes
            data to send.
            If opcode is OPCODE_TEXT and this value is a str,
            it is encoded to UTF-8 bytes automatically.
        opcode: int
            operation code. please see OPCODE_MAP.
        fin: int
            fin flag. if set to 0, create continue fragmentation.
        """
        if opcode == ABNF.OPCODE_TEXT and isinstance(data, str):
            data = data.encode("utf-8")
        # mask must be set if send data from client
        return ABNF(fin, 0, 0, 0, opcode, 1, data)

    def format(self) -> bytes:
        """
        Format this object to string(byte array) to send data to server.

        Returns
        -------
        bytes
            the frame header followed by either the raw payload (mask flag
            unset) or the 4-byte mask key and the masked payload.

        Raises
        ------
        ValueError
            if a flag bit is not 0 or 1, the opcode is unknown, or the
            payload is too long.
        """
        if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
            raise ValueError("not 0 or 1")
        if self.opcode not in ABNF.OPCODES:
            raise ValueError("Invalid OPCODE")

        # An unmasked str payload goes on the wire as UTF-8, so it has to be
        # encoded before it is measured; otherwise the length field counts
        # characters instead of bytes.  (A masked str payload is encoded as
        # latin-1 by mask(), one byte per character, so len() is correct.)
        if not self.mask_value and isinstance(self.data, str):
            self.data = self.data.encode("utf-8")

        length = len(self.data)
        if length >= ABNF.LENGTH_63:
            raise ValueError("data is too long")

        frame_header = chr(
            self.fin << 7
            | self.rsv1 << 6
            | self.rsv2 << 5
            | self.rsv3 << 4
            | self.opcode
        ).encode("latin-1")
        if length < ABNF.LENGTH_7:
            frame_header += chr(self.mask_value << 7 | length).encode("latin-1")
        elif length < ABNF.LENGTH_16:
            # 0x7E announces a 16-bit extended length.
            frame_header += chr(self.mask_value << 7 | 0x7E).encode("latin-1")
            frame_header += struct.pack("!H", length)
        else:
            # 0x7F announces a 64-bit extended length.
            frame_header += chr(self.mask_value << 7 | 0x7F).encode("latin-1")
            frame_header += struct.pack("!Q", length)

        if not self.mask_value:
            return frame_header + self.data
        mask_key = self.get_mask_key(4)
        return frame_header + self._get_masked(mask_key)

    def _get_masked(self, mask_key: Union[str, bytes]) -> bytes:
        """
        Return ``mask_key`` followed by the payload masked with it.

        A str key is encoded as latin-1 once, up front, so the key written
        into the frame is byte-for-byte the key used by mask().
        """
        if isinstance(mask_key, str):
            mask_key = mask_key.encode("latin-1")

        return mask_key + ABNF.mask(mask_key, self.data)

    @staticmethod
    def mask(mask_key: Union[str, bytes], data: Union[str, bytes]) -> bytes:
        """
        Mask or unmask data. Just do xor for each byte

        Parameters
        ----------
        mask_key: bytes or str
            4 byte mask. A str is encoded as latin-1.
        data: bytes or str
            data to mask/unmask. A str is encoded as latin-1; None is
            treated as empty.
        """
        if data is None:
            data = ""

        if isinstance(mask_key, str):
            mask_key = mask_key.encode("latin-1")

        if isinstance(data, str):
            data = data.encode("latin-1")

        return _mask(array.array("B", mask_key), array.array("B", data))

293 

294 

class frame_buffer:
    """
    Reads WebSocket frames piece by piece from a ``recv``-style callable.

    Header, length and mask are stored on the instance as they arrive and
    are reset by clear() once the payload of a frame has been read.
    """

    # Positions of the mask flag and the 7-bit length field in self.header.
    _HEADER_MASK_INDEX = 5
    _HEADER_LENGTH_INDEX = 6

    def __init__(
        self, recv_fn: Callable[[int], bytes], skip_utf8_validation: bool
    ) -> None:
        """
        Parameters
        ----------
        recv_fn: callable
            called with a maximum byte count and must return the bytes
            read (recv_strict() measures and joins the returned chunks).
        skip_utf8_validation: bool
            passed to ABNF.validate() for every received frame.
        """
        self.recv = recv_fn
        self.skip_utf8_validation = skip_utf8_validation
        # Buffers over the packets from the layer beneath until desired amount
        # bytes of bytes are received.
        self.recv_buffer: list = []
        self.clear()
        self.lock = Lock()

    def clear(self) -> None:
        """Forget the header, length and mask of the current frame."""
        self.header: Optional[tuple] = None
        self.length: Optional[int] = None
        self.mask_value: Union[bytes, str, None] = None

    def has_received_header(self) -> bool:
        """Return True while the header is still unset, i.e. must be read.

        NOTE: despite the name, True means "not received yet"; recv_frame()
        relies on exactly this meaning.
        """
        return self.header is None

    def recv_header(self) -> None:
        """Read the first two frame bytes and store the decoded fields."""
        header = self.recv_strict(2)
        b1 = header[0]
        fin = b1 >> 7 & 1
        rsv1 = b1 >> 6 & 1
        rsv2 = b1 >> 5 & 1
        rsv3 = b1 >> 4 & 1
        opcode = b1 & 0xF
        b2 = header[1]
        has_mask = b2 >> 7 & 1
        length_bits = b2 & 0x7F

        self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)

    def has_mask(self) -> Union[bool, int]:
        """Return the mask flag of the stored header, or False without one."""
        if not self.header:
            return False
        header_val: int = self.header[frame_buffer._HEADER_MASK_INDEX]
        return header_val

    def has_received_length(self) -> bool:
        """Return True while the length is still unset, i.e. must be read."""
        return self.length is None

    def recv_length(self) -> None:
        """Resolve the payload length, reading extended length bytes if any."""
        bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
        length_bits = bits & 0x7F
        if length_bits == 0x7E:
            # 16-bit big-endian extended length follows.
            v = self.recv_strict(2)
            self.length = struct.unpack("!H", v)[0]
        elif length_bits == 0x7F:
            # 64-bit big-endian extended length follows.
            v = self.recv_strict(8)
            self.length = struct.unpack("!Q", v)[0]
        else:
            self.length = length_bits

    def has_received_mask(self) -> bool:
        """Return True while the mask is still unset, i.e. must be read."""
        return self.mask_value is None

    def recv_mask(self) -> None:
        """Read the 4-byte mask key, or store "" when the frame is unmasked."""
        self.mask_value = self.recv_strict(4) if self.has_mask() else ""

    def recv_frame(self) -> "ABNF":
        """
        Receive one complete frame, validate it and return it.

        Parts that were already read by an earlier, interrupted call are
        kept on the instance and are not read again.
        """
        with self.lock:
            # Header
            if self.has_received_header():
                self.recv_header()
            (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header

            # Frame length
            if self.has_received_length():
                self.recv_length()
            length = self.length

            # Mask
            if self.has_received_mask():
                self.recv_mask()
            mask_value = self.mask_value

            # Payload
            payload = self.recv_strict(length)
            if has_mask:
                payload = ABNF.mask(mask_value, payload)

            # Reset for next frame
            self.clear()

            frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
            frame.validate(self.skip_utf8_validation)

        return frame

    def recv_strict(self, bufsize: int) -> bytes:
        """
        Return exactly ``bufsize`` bytes, calling self.recv until available.

        Bytes received beyond ``bufsize`` stay in self.recv_buffer for the
        next call.
        """
        # NOTE(review): a recv_fn that returns b"" never reduces the shortage,
        # so this loop relies on recv_fn raising at end of stream - confirm.
        shortage = bufsize - sum(map(len, self.recv_buffer))
        while shortage > 0:
            # Limit buffer size that we pass to socket.recv() to avoid
            # fragmenting the heap -- the number of bytes recv() actually
            # reads is limited by socket buffer and is relatively small,
            # yet passing large numbers repeatedly causes lots of large
            # buffers allocated and then shrunk, which results in
            # fragmentation.
            bytes_ = self.recv(min(16384, shortage))
            self.recv_buffer.append(bytes_)
            shortage -= len(bytes_)

        unified = b"".join(self.recv_buffer)

        if shortage == 0:
            self.recv_buffer = []
            return unified
        else:
            self.recv_buffer = [unified[bufsize:]]
            return unified[:bufsize]

410 

411 

class continuous_frame:
    """Collects the fragments of a fragmented message into one payload."""

    def __init__(self, fire_cont_frame: bool, skip_utf8_validation: bool) -> None:
        """
        Parameters
        ----------
        fire_cont_frame: bool
            when true, is_fire() reports every frame, not only final ones.
        skip_utf8_validation: bool
            when true, extract() does not check text payloads for valid UTF-8.
        """
        self.fire_cont_frame = fire_cont_frame
        self.skip_utf8_validation = skip_utf8_validation
        # [opcode, accumulated payload] of the message being assembled.
        self.cont_data: Optional[list] = None
        # Opcode of the message whose fragments are still arriving.
        self.recving_frames: Optional[int] = None

    def validate(self, frame: "ABNF") -> None:
        """Raise WebSocketProtocolException if ``frame`` is out of sequence."""
        # A continuation frame needs a message in progress ...
        if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT:
            raise WebSocketProtocolException("Illegal frame")
        # ... and a new data frame must not start while one is in progress.
        starts_message = (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY)
        if self.recving_frames and frame.opcode in starts_message:
            raise WebSocketProtocolException("Illegal frame")

    def add(self, frame: "ABNF") -> None:
        """Append ``frame`` to the pending message, starting one if needed."""
        if not self.cont_data:
            if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
                self.recving_frames = frame.opcode
            self.cont_data = [frame.opcode, frame.data]
        else:
            self.cont_data[1] += frame.data

        # The final fragment ends the message in progress.
        if frame.fin:
            self.recving_frames = None

    def is_fire(self, frame: "ABNF") -> Union[bool, int]:
        """Return a truthy value when the collected data should be delivered."""
        if frame.fin:
            return frame.fin
        return self.fire_cont_frame

    def extract(self, frame: "ABNF") -> tuple:
        """
        Hand out the collected payload and reset the collector.

        The payload is stored on ``frame.data``; the return value is the
        opcode recorded for the message together with ``frame``.

        Raises
        ------
        WebSocketPayloadException
            if a text message fails UTF-8 validation (only checked when
            neither fire_cont_frame nor skip_utf8_validation is set).
        """
        collected = self.cont_data
        self.cont_data = None
        frame.data = collected[1]

        no_check_needed = (
            self.fire_cont_frame
            or collected[0] != ABNF.OPCODE_TEXT
            or self.skip_utf8_validation
        )
        if no_check_needed:
            return collected[0], frame

        if not validate_utf8(frame.data):
            raise WebSocketPayloadException(f"cannot decode: {repr(frame.data)}")
        return collected[0], frame