Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/pymysql/protocol.py: 31%

174 statements  

« prev     ^ index     » next       coverage.py v7.3.1, created at 2023-09-25 06:28 +0000

1# Python implementation of low level MySQL client-server protocol 

2# http://dev.mysql.com/doc/internals/en/client-server-protocol.html 

3 

4from .charset import MBLENGTH 

5from .constants import FIELD_TYPE, SERVER_STATUS 

6from . import err 

7 

8import struct 

9import sys 

10 

11 

# When True, diagnostic output (error text, packet dumps, parsed status/filename
# fields) is printed while packets are read.
DEBUG = False

# First-byte markers of a MySQL length-encoded integer, as interpreted by
# MysqlPacket.read_length_encoded_integer():
#   * a byte strictly below 0xFB is the value itself;
#   * 0xFB stands for SQL NULL;
#   * 0xFC / 0xFD / 0xFE announce a 2- / 3- / 8-byte little-endian integer.
NULL_COLUMN = 0xFB
UNSIGNED_CHAR_COLUMN = 0xFB  # exclusive upper bound for single-byte values
UNSIGNED_SHORT_COLUMN = 0xFC
UNSIGNED_INT24_COLUMN = 0xFD
UNSIGNED_INT64_COLUMN = 0xFE

19 

20 

def dump_packet(data):  # pragma: no cover
    """Print a hex/ASCII dump of *data* for debugging.

    Output: the packet length, up to six caller frames (best-effort), then at
    most the first 256 bytes as rows of 16 -- hex codes on the left and their
    printable-ASCII rendering on the right ("." for bytes outside 0x20-0x7E).
    """

    def printable(data):
        if 32 <= data < 127:
            return chr(data)
        return "."

    try:
        print("packet length:", len(data))
        for i in range(1, 7):
            f = sys._getframe(i)
            print("call[%d]: %s (line %d)" % (i, f.f_code.co_name, f.f_lineno))
        print("-" * 66)
    except ValueError:
        # sys._getframe raises ValueError when the stack is shallower than
        # requested; the caller listing is only a convenience.
        pass
    dump_data = [data[i : i + 16] for i in range(0, min(len(data), 256), 16)]
    for d in dump_data:
        # Every byte takes three columns ("XX ") in the hex area, so a short
        # final row is padded by three spaces per missing byte to keep the
        # ASCII column aligned with full rows.
        print(
            " ".join(f"{x:02X}" for x in d)
            + "   " * (16 - len(d))
            + " " * 2
            + "".join(printable(x) for x in d)
        )
    print("-" * 66)
    print()

45 

46 

class MysqlPacket:
    """Representation of a MySQL response packet.

    Holds the packet payload plus a read cursor and provides an interface for
    reading/parsing the packet contents.
    """

    __slots__ = ("_position", "_data")

    def __init__(self, data, encoding):
        # ``encoding`` is not used here; it is accepted so subclasses such as
        # FieldDescriptorPacket can share the same constructor signature.
        self._position = 0
        self._data = data

    def get_all_data(self):
        """Return the whole payload, independent of the cursor position."""
        return self._data

    def read(self, size):
        """Read the next 'size' bytes at the cursor and advance past them.

        :raises AssertionError: if fewer than 'size' bytes remain.
        """
        result = self._data[self._position : (self._position + size)]
        if len(result) != size:
            error = (
                "Result length not requested length:\n"
                "Expected=%s. Actual=%s. Position: %s. Data Length: %s"
                % (size, len(result), self._position, len(self._data))
            )
            if DEBUG:
                print(error)
                self.dump()
            raise AssertionError(error)
        self._position += size
        return result

    def read_all(self):
        """Read all remaining data in the packet.

        The cursor is invalidated afterwards, so any subsequent cursor-based
        read (read(), read_uint*(), ...) raises an error.
        """
        result = self._data[self._position :]
        self._position = None  # ensure no subsequent read()
        return result

    def advance(self, length):
        """Advance the cursor in the data buffer by 'length' bytes (may be negative).

        :raises ValueError: if the new position would fall outside the buffer.
            ValueError subclasses Exception, so existing ``except Exception``
            handlers keep working.
        """
        new_position = self._position + length
        if new_position < 0 or new_position > len(self._data):
            raise ValueError(
                "Invalid advance amount (%s) for cursor. "
                "Position=%s" % (length, new_position)
            )
        self._position = new_position

    def rewind(self, position=0):
        """Set the data buffer cursor to absolute 'position' (default: 0).

        :raises ValueError: if 'position' is outside the buffer.
        """
        if position < 0 or position > len(self._data):
            raise ValueError("Invalid position to rewind cursor to: %s." % position)
        self._position = position

    def get_bytes(self, position, length=1):
        """Get 'length' bytes starting at 'position'.

        Position is start of payload (first four packet header bytes are not
        included) starting at index '0'. The cursor is not moved.

        No error checking is done. If requesting outside end of buffer
        an empty string (or string shorter than 'length') may be returned!
        """
        return self._data[position : (position + length)]

    def read_uint8(self):
        """Read one unsigned byte and advance the cursor by 1."""
        result = self._data[self._position]
        self._position += 1
        return result

    def read_uint16(self):
        """Read a little-endian unsigned 16-bit integer."""
        result = struct.unpack_from("<H", self._data, self._position)[0]
        self._position += 2
        return result

    def read_uint24(self):
        """Read a little-endian unsigned 24-bit integer."""
        low, high = struct.unpack_from("<HB", self._data, self._position)
        self._position += 3
        return low + (high << 16)

    def read_uint32(self):
        """Read a little-endian unsigned 32-bit integer."""
        result = struct.unpack_from("<I", self._data, self._position)[0]
        self._position += 4
        return result

    def read_uint64(self):
        """Read a little-endian unsigned 64-bit integer."""
        result = struct.unpack_from("<Q", self._data, self._position)[0]
        self._position += 8
        return result

    def read_string(self):
        """Read a NUL-terminated byte string and skip past the terminator.

        Returns None (cursor unchanged) when no NUL byte follows the cursor.
        """
        end_pos = self._data.find(b"\0", self._position)
        if end_pos < 0:
            return None
        result = self._data[self._position : end_pos]
        self._position = end_pos + 1
        return result

    def read_length_encoded_integer(self):
        """Read a 'Length Coded Binary' number from the data buffer.

        Length coded numbers can be anywhere from 1 to 9 bytes depending
        on the value of the first byte. Returns None for the NULL marker.
        """
        c = self.read_uint8()
        if c == NULL_COLUMN:
            return None
        if c < UNSIGNED_CHAR_COLUMN:
            return c
        elif c == UNSIGNED_SHORT_COLUMN:
            return self.read_uint16()
        elif c == UNSIGNED_INT24_COLUMN:
            return self.read_uint24()
        elif c == UNSIGNED_INT64_COLUMN:
            return self.read_uint64()
        # Only 0xFF is left: it matches no marker and yields None.
        return None

    def read_length_coded_string(self):
        """Read a 'Length Coded String' from the data buffer.

        A 'Length Coded String' consists first of a length coded
        (unsigned, positive) integer represented in 1-9 bytes followed by
        that many bytes of binary data. (For example "cat" would be "3cat".)
        Returns None when the length is the NULL marker.
        """
        length = self.read_length_encoded_integer()
        if length is None:
            return None
        return self.read(length)

    def read_struct(self, fmt):
        """Unpack 'fmt' (a struct format) at the cursor and advance past it."""
        s = struct.Struct(fmt)
        result = s.unpack_from(self._data, self._position)
        self._position += s.size
        return result

    def is_ok_packet(self):
        # https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
        # Check the length first so an empty payload yields False instead of
        # raising IndexError.
        return len(self._data) >= 7 and self._data[0] == 0

    def is_eof_packet(self):
        # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
        # Caution: \xFE may be LengthEncodedInteger.
        # If \xFE is LengthEncodedInteger header, 8bytes followed.
        return self._data[0] == 0xFE and len(self._data) < 9

    def is_auth_switch_request(self):
        # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        return self._data[0] == 0xFE

    def is_extra_auth_data(self):
        # https://dev.mysql.com/doc/internals/en/successful-authentication.html
        return self._data[0] == 1

    def is_resultset_packet(self):
        # A first byte of 1..250 is treated as a result-set column count.
        field_count = self._data[0]
        return 1 <= field_count <= 250

    def is_load_local_packet(self):
        return self._data[0] == 0xFB

    def is_error_packet(self):
        return self._data[0] == 0xFF

    def check_error(self):
        """Call raise_for_error() if this is an error packet (first byte 0xFF)."""
        if self.is_error_packet():
            self.raise_for_error()

    def raise_for_error(self):
        """Hand the error payload to err.raise_mysql_exception (expected to raise)."""
        self.rewind()
        self.advance(1)  # field_count == error (we already know that)
        errno = self.read_uint16()
        if DEBUG:
            print("errno =", errno)
        err.raise_mysql_exception(self._data)

    def dump(self):
        """Print a hex/ASCII dump of the payload via dump_packet()."""
        dump_packet(self._data)

225 

226 

class FieldDescriptorPacket(MysqlPacket):
    """A MysqlPacket holding the metadata of one result-set column.

    The payload is parsed eagerly by the constructor and exposed through
    public attributes: catalog, db, table_name, org_table, name, org_name,
    charsetnr, length, type_code, flags and scale.
    """

    def __init__(self, data, encoding):
        MysqlPacket.__init__(self, data, encoding)
        self._parse_field_descriptor(encoding)

    def _parse_field_descriptor(self, encoding):
        """Parse the column-definition ('Field Descriptor') payload.

        Compatible with the MySQL 4.1+ protocol only (not MySQL 4.0).
        """
        next_field = self.read_length_coded_string

        # catalog and db are kept as raw bytes; the other names are decoded.
        self.catalog = next_field()
        self.db = next_field()
        self.table_name = next_field().decode(encoding)
        self.org_table = next_field().decode(encoding)
        self.name = next_field().decode(encoding)
        self.org_name = next_field().decode(encoding)

        # Fixed-size tail: one skipped byte, charset (u16), length (u32),
        # type (u8), flags (u16), scale (u8), two skipped bytes.
        tail = self.read_struct("<xHIBHBxx")
        self.charsetnr, self.length, self.type_code, self.flags, self.scale = tail
        # A length-coded 'default' value may follow; it is not read here
        # because normal result sets do not use it.

    def description(self):
        """Return a 7-item tuple compatible with the PEP 249 DB API spec.

        Order: (name, type_code, display_size, internal_size, precision,
        scale, null_ok); null_ok is True when bit 0 of ``flags`` is clear.
        """
        column_length = self.get_column_length()
        null_ok = (self.flags & 1) == 0
        return (
            self.name,
            self.type_code,
            None,  # display_size -- TODO: should this be self.length?
            column_length,  # internal_size
            column_length,  # precision -- TODO: why is this the same value?
            self.scale,
            null_ok,
        )

    def get_column_length(self):
        """Return the column length, scaled down for VAR_STRING columns.

        For VAR_STRING the byte length is floor-divided by
        MBLENGTH[charsetnr] (presumably the charset's bytes per character;
        1 when the charset is not listed). Other types return ``length``.
        """
        if self.type_code != FIELD_TYPE.VAR_STRING:
            return self.length
        return self.length // MBLENGTH.get(self.charsetnr, 1)

    def __str__(self):
        return (
            f"{self.__class__} {self.db!r}.{self.table_name!r}.{self.name!r}, "
            f"type={self.type_code}, flags={self.flags:x}"
        )

286 

287 

class OKPacketWrapper:
    """
    OK Packet Wrapper. It uses an existing packet object, and wraps
    around it, exposing useful variables while still providing access
    to the original packet objects variables and methods.

    Parsed attributes: affected_rows, insert_id, server_status,
    warning_count, message (remaining payload bytes) and has_next (the
    SERVER_MORE_RESULTS_EXISTS bit of server_status).
    """

    def __init__(self, from_packet):
        """Parse *from_packet*, which must satisfy ``is_ok_packet()``.

        :raises ValueError: if *from_packet* is not an OK packet.
        """
        if not from_packet.is_ok_packet():
            raise ValueError(
                "Cannot create "
                + str(self.__class__.__name__)
                + " object from invalid packet type"
            )

        self.packet = from_packet
        # Skip the leading 0x00 byte checked by is_ok_packet() (assumes the
        # packet's cursor is still at the start of the payload).
        self.packet.advance(1)

        self.affected_rows = self.packet.read_length_encoded_integer()
        self.insert_id = self.packet.read_length_encoded_integer()
        self.server_status, self.warning_count = self.packet.read_struct("<HH")
        self.message = self.packet.read_all()
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        """Delegate unknown attribute lookups to the wrapped packet."""
        # __getattr__ only runs when normal lookup fails. If ``packet`` itself
        # is missing (e.g. an instance created via __new__ by copy/pickle,
        # which probe __setstate__ with hasattr), delegating would recurse
        # forever; raise AttributeError so hasattr()/getattr() work normally.
        if key == "packet":
            raise AttributeError(key)
        return getattr(self.packet, key)

314 

315 

class EOFPacketWrapper:
    """
    EOF Packet Wrapper. It uses an existing packet object, and wraps
    around it, exposing useful variables while still providing access
    to the original packet objects variables and methods.

    Parsed attributes: warning_count, server_status and has_next (the
    SERVER_MORE_RESULTS_EXISTS bit of server_status).
    """

    def __init__(self, from_packet):
        """Parse *from_packet*, which must satisfy ``is_eof_packet()``.

        :raises ValueError: if *from_packet* is not an EOF packet.
        """
        if not from_packet.is_eof_packet():
            raise ValueError(
                f"Cannot create '{self.__class__}' object from invalid packet type"
            )

        self.packet = from_packet
        # After the 0xFE header byte (skipped by "x"): warning count and status
        # flags, both unsigned 16-bit little-endian. Unsigned "H" matches the
        # "<HH" used by OKPacketWrapper; signed "h" would report warning counts
        # above 32767 as negative numbers.
        self.warning_count, self.server_status = self.packet.read_struct("<xHH")
        if DEBUG:
            print("server_status=", self.server_status)
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        """Delegate unknown attribute lookups to the wrapped packet."""
        # Guard against infinite recursion when ``packet`` is not set yet
        # (e.g. an instance created via __new__ by copy/pickle).
        if key == "packet":
            raise AttributeError(key)
        return getattr(self.packet, key)

337 

338 

class LoadLocalPacketWrapper:
    """
    Load Local Packet Wrapper. It uses an existing packet object, and wraps
    around it, exposing useful variables while still providing access
    to the original packet objects variables and methods.

    Parsed attribute: filename -- the payload bytes after the leading 0xFB
    byte checked by is_load_local_packet().
    """

    def __init__(self, from_packet):
        """Wrap *from_packet*, which must satisfy ``is_load_local_packet()``.

        :raises ValueError: if *from_packet* is not a LOCAL INFILE request.
        """
        if not from_packet.is_load_local_packet():
            raise ValueError(
                f"Cannot create '{self.__class__}' object from invalid packet type"
            )

        self.packet = from_packet
        self.filename = self.packet.get_all_data()[1:]
        if DEBUG:
            print("filename=", self.filename)

    def __getattr__(self, key):
        """Delegate unknown attribute lookups to the wrapped packet."""
        # Guard against infinite recursion when ``packet`` is not set yet
        # (e.g. an instance created via __new__ by copy/pickle).
        if key == "packet":
            raise AttributeError(key)
        return getattr(self.packet, key)