Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/websocket/_abnf.py: 32%

226 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:34 +0000

1import array 

2import os 

3import struct 

4import sys 

5 

6from threading import Lock 

7from typing import Callable, Union 

8 

9from ._exceptions import * 

10from ._utils import validate_utf8 

11 

12""" 

13_abnf.py 

14websocket - WebSocket client library for Python 

15 

16Copyright 2023 engn33r 

17 

18Licensed under the Apache License, Version 2.0 (the "License"); 

19you may not use this file except in compliance with the License. 

20You may obtain a copy of the License at 

21 

22 http://www.apache.org/licenses/LICENSE-2.0 

23 

24Unless required by applicable law or agreed to in writing, software 

25distributed under the License is distributed on an "AS IS" BASIS, 

26WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

27See the License for the specific language governing permissions and 

28limitations under the License. 

29""" 

30 

try:
    # wsaccel, when installed, ships a compiled XOR masker. It is only around
    # 10% faster than the pure-Python _mask() below and the project is
    # unmaintained, so it is strictly optional.
    from wsaccel.xormask import XorMaskerSimple
except ImportError:
    # Pure-Python fallback. The payload is read as one big integer and XORed
    # with the 4-byte mask repeated to the payload's length. Using the same
    # byte order for from_bytes() and to_bytes() makes the XOR byte-wise.
    native_byteorder = sys.byteorder

    def _mask(mask_value: array.array, data_value: array.array) -> bytes:
        size = len(data_value)
        full_repeats, remainder = divmod(size, 4)
        key_stream = mask_value * full_repeats + mask_value[:remainder]
        masked = int.from_bytes(data_value, native_byteorder) ^ int.from_bytes(key_stream, native_byteorder)
        return masked.to_bytes(size, native_byteorder)
else:
    def _mask(_m, _d) -> bytes:
        # Delegate to wsaccel's compiled implementation.
        return XorMaskerSimple(_m).process(_d)

50 

51 

__all__ = [
    'ABNF', 'continuous_frame', 'frame_buffer',
    'STATUS_NORMAL',
    'STATUS_GOING_AWAY',
    'STATUS_PROTOCOL_ERROR',
    'STATUS_UNSUPPORTED_DATA_TYPE',
    'STATUS_STATUS_NOT_AVAILABLE',
    'STATUS_ABNORMAL_CLOSED',
    'STATUS_INVALID_PAYLOAD',
    'STATUS_POLICY_VIOLATION',
    'STATUS_MESSAGE_TOO_BIG',
    'STATUS_INVALID_EXTENSION',
    'STATUS_UNEXPECTED_CONDITION',
    # These two were defined (and accepted in VALID_CLOSE_STATUS) but were
    # previously missing from the exported names.
    'STATUS_SERVICE_RESTART',
    'STATUS_TRY_AGAIN_LATER',
    'STATUS_BAD_GATEWAY',
    'STATUS_TLS_HANDSHAKE_ERROR',
]

# closing frame status codes (RFC 6455 section 7.4.1 / IANA WebSocket
# Close Code Number Registry).
STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA_TYPE = 1003
STATUS_STATUS_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSED = 1006
STATUS_INVALID_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_INVALID_EXTENSION = 1010
STATUS_UNEXPECTED_CONDITION = 1011
STATUS_SERVICE_RESTART = 1012
STATUS_TRY_AGAIN_LATER = 1013
STATUS_BAD_GATEWAY = 1014
STATUS_TLS_HANDSHAKE_ERROR = 1015

# Status codes accepted when they appear in a close frame's payload.
# 1005 (STATUS_STATUS_NOT_AVAILABLE), 1006 (STATUS_ABNORMAL_CLOSED) and
# 1015 (STATUS_TLS_HANDSHAKE_ERROR) are deliberately absent: RFC 6455
# reserves them for local reporting and forbids sending them in a close frame.
VALID_CLOSE_STATUS = (
    STATUS_NORMAL,
    STATUS_GOING_AWAY,
    STATUS_PROTOCOL_ERROR,
    STATUS_UNSUPPORTED_DATA_TYPE,
    STATUS_INVALID_PAYLOAD,
    STATUS_POLICY_VIOLATION,
    STATUS_MESSAGE_TOO_BIG,
    STATUS_INVALID_EXTENSION,
    STATUS_UNEXPECTED_CONDITION,
    STATUS_SERVICE_RESTART,
    STATUS_TRY_AGAIN_LATER,
    STATUS_BAD_GATEWAY,
)

100 

101 

class ABNF:
    """
    ABNF frame class.
    See http://tools.ietf.org/html/rfc5234
    and http://tools.ietf.org/html/rfc6455#section-5.2
    """

    # operation code values.
    OPCODE_CONT = 0x0
    OPCODE_TEXT = 0x1
    OPCODE_BINARY = 0x2
    OPCODE_CLOSE = 0x8
    OPCODE_PING = 0x9
    OPCODE_PONG = 0xa

    # available operation code value tuple
    OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE,
               OPCODE_PING, OPCODE_PONG)

    # opcode human readable string
    OPCODE_MAP = {
        OPCODE_CONT: "cont",
        OPCODE_TEXT: "text",
        OPCODE_BINARY: "binary",
        OPCODE_CLOSE: "close",
        OPCODE_PING: "ping",
        OPCODE_PONG: "pong"
    }

    # data length threshold.
    # format() writes payloads shorter than LENGTH_7 directly in the 7-bit
    # length field, shorter than LENGTH_16 as a 16-bit extended length, and
    # anything else (which must be below LENGTH_63) as a 64-bit length.
    LENGTH_7 = 0x7e
    LENGTH_16 = 1 << 16
    LENGTH_63 = 1 << 63

    def __init__(self, fin: int = 0, rsv1: int = 0, rsv2: int = 0, rsv3: int = 0,
                 opcode: int = OPCODE_TEXT, mask: int = 1, data: Union[str, bytes] = "") -> None:
        """
        Constructor for ABNF. Please check RFC for arguments.

        ``data`` of ``None`` is normalised to an empty string.
        """
        self.fin = fin
        self.rsv1 = rsv1
        self.rsv2 = rsv2
        self.rsv3 = rsv3
        self.opcode = opcode
        self.mask = mask
        if data is None:
            data = ""
        self.data = data
        # Callable taking a byte count and returning that many bytes; format()
        # calls it with 4 to obtain the masking key. Replaceable by callers.
        self.get_mask_key = os.urandom

    def validate(self, skip_utf8_validation: bool = False) -> None:
        """
        Validate the ABNF frame.

        Parameters
        ----------
        skip_utf8_validation: skip utf8 validation.

        Raises
        ------
        WebSocketProtocolException
            If a reserved bit is set, the opcode is unknown, a ping frame is
            not final, or a close frame has a malformed payload (length 1 or
            >= 126, invalid UTF-8 reason, or an unacceptable status code).
        """
        if self.rsv1 or self.rsv2 or self.rsv3:
            raise WebSocketProtocolException("rsv is not implemented, yet")

        if self.opcode not in ABNF.OPCODES:
            # Format the message here; passing the value as a second exception
            # argument left "%r" unformatted in the error text.
            raise WebSocketProtocolException("Invalid opcode %r" % (self.opcode,))

        if self.opcode == ABNF.OPCODE_PING and not self.fin:
            raise WebSocketProtocolException("Invalid ping frame.")

        if self.opcode == ABNF.OPCODE_CLOSE:
            payload_len = len(self.data)
            if not payload_len:
                # An empty close frame carries no status code: nothing to check.
                return
            if payload_len == 1 or payload_len >= 126:
                raise WebSocketProtocolException("Invalid close frame.")
            # Bytes after the 2-byte status code are the UTF-8 close reason.
            if payload_len > 2 and not skip_utf8_validation and not validate_utf8(self.data[2:]):
                raise WebSocketProtocolException("Invalid close frame.")

            # First two payload bytes: big-endian status code.
            code = 256 * self.data[0] + self.data[1]
            if not self._is_valid_close_status(code):
                raise WebSocketProtocolException("Invalid close opcode %r" % (code,))

    @staticmethod
    def _is_valid_close_status(code: int) -> bool:
        # Known codes from VALID_CLOSE_STATUS, or the 3000-4999 range.
        return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)

    def __str__(self) -> str:
        return "fin=" + str(self.fin) \
            + " opcode=" + str(self.opcode) \
            + " data=" + str(self.data)

    @staticmethod
    def create_frame(data: Union[bytes, str], opcode: int, fin: int = 1) -> 'ABNF':
        """
        Create frame to send text, binary and other data.

        Parameters
        ----------
        data: str
            data to send. This is string value(byte array).
            If opcode is OPCODE_TEXT and this value is a str,
            it is encoded to UTF-8 bytes automatically.
        opcode: int
            operation code. please see OPCODE_MAP.
        fin: int
            fin flag. if set to 0, create continue fragmentation.

        Returns
        -------
        ABNF
            A frame with all rsv bits clear and masking enabled.
        """
        if opcode == ABNF.OPCODE_TEXT and isinstance(data, str):
            data = data.encode("utf-8")
        # mask must be set if send data from client
        return ABNF(fin, 0, 0, 0, opcode, 1, data)

    def format(self) -> bytes:
        """
        Format this object to string(byte array) to send data to server.

        ``str`` data is encoded as latin-1, the same convention ABNF.mask()
        uses, so the length in the header always matches the bytes sent.

        Returns
        -------
        bytes
            Frame header, then either the raw payload (mask == 0) or the 4-byte
            masking key followed by the masked payload (mask == 1).

        Raises
        ------
        ValueError
            If fin, rsv1-3 or mask is not 0/1, the opcode is unknown, or the
            payload is 2**63 bytes or longer.
        """
        if any(x not in (0, 1) for x in (self.fin, self.rsv1, self.rsv2, self.rsv3, self.mask)):
            raise ValueError("not 0 or 1")
        if self.opcode not in ABNF.OPCODES:
            raise ValueError("Invalid OPCODE")

        data = self.data
        if isinstance(data, str):
            data = data.encode('latin-1')
        length = len(data)
        if length >= ABNF.LENGTH_63:
            raise ValueError("data is too long")

        # Every field is validated above, so both header bytes are in 0..255.
        frame_header = bytes([self.fin << 7 | self.rsv1 << 6 | self.rsv2 << 5
                              | self.rsv3 << 4 | self.opcode])
        mask_bit = self.mask << 7
        if length < ABNF.LENGTH_7:
            frame_header += bytes([mask_bit | length])
        elif length < ABNF.LENGTH_16:
            frame_header += bytes([mask_bit | 0x7e]) + struct.pack("!H", length)
        else:
            frame_header += bytes([mask_bit | 0x7f]) + struct.pack("!Q", length)

        if not self.mask:
            return frame_header + data
        mask_key = self.get_mask_key(4)
        return frame_header + self._get_masked(mask_key)

    def _get_masked(self, mask_key: Union[str, bytes]) -> bytes:
        """Return ``mask_key`` (as bytes) followed by ``self.data`` masked with it."""
        if isinstance(mask_key, str):
            # Encode exactly as ABNF.mask() does (latin-1). Previously this used
            # UTF-8, so a key with code points 128-255 was written as more than
            # 4 bytes while the payload was masked with the latin-1 bytes.
            mask_key = mask_key.encode('latin-1')
        return mask_key + ABNF.mask(mask_key, self.data)

    @staticmethod
    def mask(mask_key: Union[str, bytes], data: Union[str, bytes]) -> bytes:
        """
        Mask or unmask data. Just do xor for each byte

        Parameters
        ----------
        mask_key: bytes or str
            4 byte mask.
        data: bytes or str
            data to mask/unmask. ``None`` is treated as empty.

        Returns
        -------
        bytes
            ``data`` XORed with ``mask_key`` repeated to its length. ``str``
            arguments are encoded as latin-1 first.
        """
        if data is None:
            data = ""

        if isinstance(mask_key, str):
            mask_key = mask_key.encode('latin-1')

        if isinstance(data, str):
            data = data.encode('latin-1')

        return _mask(array.array("B", mask_key), array.array("B", data))

272 

273 

class frame_buffer:
    """
    Reads WebSocket frames from a byte source, one stage at a time.

    Parsing state (header, length, mask) and over-read bytes live on the
    instance and are only reset once a whole frame has been read. If
    ``recv_fn`` raises part-way through a frame, the next ``recv_frame()``
    call therefore resumes where the previous one stopped.
    """

    # Positions inside the tuple stored in ``self.header`` by recv_header().
    _HEADER_MASK_INDEX = 5
    _HEADER_LENGTH_INDEX = 6
    # Upper bound on the size requested from recv_fn in a single call.
    _RECV_CHUNK_LIMIT = 16384

    def __init__(self, recv_fn: Callable[[int], bytes], skip_utf8_validation: bool) -> None:
        """
        Parameters
        ----------
        recv_fn: callable taking a maximum byte count and returning bytes.
        skip_utf8_validation: forwarded to ABNF.validate() for each frame.
        """
        self.recv = recv_fn
        self.skip_utf8_validation = skip_utf8_validation
        # Buffers over the packets from the layer beneath until the desired
        # number of bytes has been received.
        self.recv_buffer = []
        self.clear()
        # Held for the whole of recv_frame().
        self.lock = Lock()

    def clear(self) -> None:
        """Forget the per-frame parsing state (the byte buffer is kept)."""
        self.header = None
        self.length = None
        self.mask = None

    def has_received_header(self) -> bool:
        # NOTE: despite the name, this is True when the header has NOT been
        # read yet, i.e. recv_header() still needs to run. Kept as-is for
        # compatibility with existing callers.
        return self.header is None

    def recv_header(self) -> None:
        """Read the 2-byte base header and store its decoded fields."""
        header = self.recv_strict(2)
        b1 = header[0]
        fin = b1 >> 7 & 1
        rsv1 = b1 >> 6 & 1
        rsv2 = b1 >> 5 & 1
        rsv3 = b1 >> 4 & 1
        opcode = b1 & 0xf
        b2 = header[1]
        has_mask = b2 >> 7 & 1
        length_bits = b2 & 0x7f

        self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)

    def has_mask(self) -> Union[bool, int]:
        """Return the header's mask bit, or False if no header is stored."""
        if not self.header:
            return False
        return self.header[frame_buffer._HEADER_MASK_INDEX]

    def has_received_length(self) -> bool:
        # NOTE: True when the length has NOT been read yet (see has_received_header).
        return self.length is None

    def recv_length(self) -> None:
        """Decode the payload length, reading the extended field if needed."""
        length_bits = self.header[frame_buffer._HEADER_LENGTH_INDEX] & 0x7f
        if length_bits == 0x7e:
            # 126: the real length follows as a 16-bit big-endian integer.
            self.length = struct.unpack("!H", self.recv_strict(2))[0]
        elif length_bits == 0x7f:
            # 127: the real length follows as a 64-bit big-endian integer.
            self.length = struct.unpack("!Q", self.recv_strict(8))[0]
        else:
            self.length = length_bits

    def has_received_mask(self) -> bool:
        # NOTE: True when the mask has NOT been read yet (see has_received_header).
        return self.mask is None

    def recv_mask(self) -> None:
        """Read the 4-byte masking key, or store b"" for an unmasked frame."""
        # b"" (not "") keeps self.mask a bytes value in both cases.
        self.mask = self.recv_strict(4) if self.has_mask() else b""

    def recv_frame(self) -> "ABNF":
        """
        Read, unmask and validate one complete frame.

        Each stage (header, length, mask) is skipped if an earlier interrupted
        call already completed it. State is cleared only after the payload
        has been read.

        Returns
        -------
        ABNF
            The validated frame.
        """
        with self.lock:
            # Header
            if self.has_received_header():
                self.recv_header()
            (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header

            # Frame length
            if self.has_received_length():
                self.recv_length()
            length = self.length

            # Mask
            if self.has_received_mask():
                self.recv_mask()
            mask = self.mask

            # Payload
            payload = self.recv_strict(length)
            if has_mask:
                payload = ABNF.mask(mask, payload)

            # Reset for next frame
            self.clear()

            frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
            frame.validate(self.skip_utf8_validation)

        return frame

    def recv_strict(self, bufsize: int) -> bytes:
        """
        Return exactly ``bufsize`` bytes, calling ``recv_fn`` as often as needed.

        Bytes already in ``recv_buffer`` are consumed first. Bytes received
        beyond ``bufsize`` are kept in ``recv_buffer`` for the next call.
        NOTE(review): if recv_fn ever returned b"" without raising, this would
        loop forever; presumably recv_fn raises on a closed connection (confirm).
        """
        shortage = bufsize - sum(map(len, self.recv_buffer))
        while shortage > 0:
            # Limit buffer size that we pass to socket.recv() to avoid
            # fragmenting the heap -- the number of bytes recv() actually
            # reads is limited by socket buffer and is relatively small,
            # yet passing large numbers repeatedly causes lots of large
            # buffers allocated and then shrunk, which results in
            # fragmentation.
            bytes_ = self.recv(min(self._RECV_CHUNK_LIMIT, shortage))
            self.recv_buffer.append(bytes_)
            shortage -= len(bytes_)

        unified = b"".join(self.recv_buffer)

        if shortage == 0:
            self.recv_buffer = []
            return unified
        else:
            # Over-read: keep the surplus for the next call.
            self.recv_buffer = [unified[bufsize:]]
            return unified[:bufsize]

387 

388 

class continuous_frame:
    """
    Reassembles a fragmented text/binary message from its frames.

    ``add()`` accumulates payloads in ``cont_data`` as ``[opcode, data]``;
    ``extract()`` hands the accumulated message back and empties the buffer.
    """

    def __init__(self, fire_cont_frame: bool, skip_utf8_validation: bool) -> None:
        self.fire_cont_frame = fire_cont_frame
        self.skip_utf8_validation = skip_utf8_validation
        # [opcode, accumulated payload] for the message being built, or None.
        self.cont_data = None
        # Opcode of the text/binary message whose final fragment has not yet
        # arrived, or None.
        self.recving_frames = None

    def validate(self, frame: ABNF) -> None:
        """
        Reject frames that break fragmentation order.

        Raises WebSocketProtocolException for a continuation frame when no
        message is open, or a new text/binary frame while one is.
        """
        if not self.recving_frames:
            if frame.opcode == ABNF.OPCODE_CONT:
                raise WebSocketProtocolException("Illegal frame")
        elif frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
            raise WebSocketProtocolException("Illegal frame")

    def add(self, frame: ABNF) -> None:
        """Append ``frame``'s payload; close the open message when ``frame.fin`` is set."""
        if not self.cont_data:
            # First fragment: remember the message type and start the buffer.
            if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
                self.recving_frames = frame.opcode
            self.cont_data = [frame.opcode, frame.data]
        else:
            self.cont_data[1] += frame.data

        if frame.fin:
            self.recving_frames = None

    def is_fire(self, frame: ABNF) -> Union[bool, int]:
        """Truthy when the accumulated data should be handed out now."""
        return frame.fin if frame.fin else self.fire_cont_frame

    def extract(self, frame: ABNF) -> list:
        """
        Move the accumulated payload into ``frame`` and reset the buffer.

        Returns ``[opcode, frame]``. Raises WebSocketPayloadException when a
        completed text message is not valid UTF-8 (unless validation is
        skipped or frames are fired individually).
        """
        pending = self.cont_data
        self.cont_data = None
        frame.data = pending[1]
        if self.fire_cont_frame:
            check_utf8 = False
        else:
            check_utf8 = pending[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation
        if check_utf8 and not validate_utf8(frame.data):
            raise WebSocketPayloadException(
                "cannot decode: " + repr(frame.data))

        return [pending[0], frame]