Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/pymysql/protocol.py: 31%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

175 statements  

1# Python implementation of low level MySQL client-server protocol 

2# http://dev.mysql.com/doc/internals/en/client-server-protocol.html 

3 

4from .charset import MBLENGTH 

5from .constants import FIELD_TYPE, SERVER_STATUS 

6from . import err 

7 

8import struct 

9import sys 

10 

11 

# When true, some code paths print extra diagnostics (and packet dumps).
DEBUG = False

# First-byte markers of a length-encoded integer, as interpreted by
# MysqlPacket.read_length_encoded_integer().
NULL_COLUMN = 0xFB  # 251: NULL marker, read back as None
UNSIGNED_CHAR_COLUMN = 0xFB  # first bytes below 251 are the value itself
UNSIGNED_SHORT_COLUMN = 0xFC  # 252: a 2-byte little-endian integer follows
UNSIGNED_INT24_COLUMN = 0xFD  # 253: a 3-byte integer follows
UNSIGNED_INT64_COLUMN = 0xFE  # 254: an 8-byte little-endian integer follows

19 

20 

def dump_packet(data):  # pragma: no cover
    """Print a hex/ASCII dump of ``data`` for debugging.

    Prints the packet length and up to six caller frames, then at most the
    first 256 bytes as rows of 16: hex values on the left and the printable
    ASCII rendering (other bytes shown as ``.``) on the right.
    """

    def printable(data):
        # Map a byte value to its ASCII character, or "." if not printable.
        if 32 <= data < 127:
            return chr(data)
        return "."

    try:
        print("packet length:", len(data))
        for i in range(1, 7):
            f = sys._getframe(i)
            print("call[%d]: %s (line %d)" % (i, f.f_code.co_name, f.f_lineno))
        print("-" * 66)
    except ValueError:
        # sys._getframe raises ValueError once the stack is shallower than i.
        pass
    dump_data = [data[i : i + 16] for i in range(0, min(len(data), 256), 16)]
    for d in dump_data:
        print(
            " ".join(f"{x:02X}" for x in d)
            # Every byte takes 3 columns ("XX "), so pad short rows by 3
            # columns per missing byte to keep the ASCII column aligned.
            + "   " * (16 - len(d))
            + " " * 2
            + "".join(printable(x) for x in d)
        )
    print("-" * 66)
    print()

45 

46 

class MysqlPacket:
    """Representation of a MySQL response packet.

    Provides an interface for reading/parsing the packet results. The
    payload is held in ``_data`` and ``_position`` is a read cursor into it.
    The ``read_*`` methods consume bytes at the cursor and advance it;
    ``advance``/``rewind`` move it explicitly. ``get_bytes``,
    ``get_all_data`` and the ``is_*`` predicates inspect the payload without
    moving the cursor.
    """

    __slots__ = ("_position", "_data")

    def __init__(self, data, encoding):
        # ``encoding`` is part of the constructor signature shared with
        # subclasses; this base class does not use it.
        self._position = 0
        self._data = data

    def get_all_data(self):
        """Return the whole payload, independent of the cursor position."""
        return self._data

    def read(self, size):
        """Read the first 'size' bytes in packet and advance cursor past them.

        Raises AssertionError if fewer than ``size`` bytes remain.
        """
        result = self._data[self._position : (self._position + size)]
        if len(result) != size:
            error = (
                "Result length not requested length:\n"
                f"Expected={size}. Actual={len(result)}. Position: {self._position}. Data Length: {len(self._data)}"
            )
            if DEBUG:
                print(error)
                self.dump()
            raise AssertionError(error)
        self._position += size
        return result

    def read_all(self):
        """Read all remaining data in the packet.

        (Subsequent read() will return errors.)
        """
        result = self._data[self._position :]
        self._position = None  # ensure no subsequent read()
        return result

    def advance(self, length):
        """Advance the cursor in data buffer 'length' bytes.

        Raises Exception if the new position falls outside the payload.
        """
        new_position = self._position + length
        if new_position < 0 or new_position > len(self._data):
            raise Exception(
                f"Invalid advance amount ({length}) for cursor. Position={new_position}"
            )
        self._position = new_position

    def rewind(self, position=0):
        """Set the position of the data buffer cursor to 'position'.

        Raises Exception if ``position`` lies outside the payload.
        """
        if position < 0 or position > len(self._data):
            raise Exception("Invalid position to rewind cursor to: %s." % position)
        self._position = position

    def get_bytes(self, position, length=1):
        """Get 'length' bytes starting at 'position'.

        Position is start of payload (first four packet header bytes are not
        included) starting at index '0'.

        No error checking is done. If requesting outside end of buffer
        an empty string (or string shorter than 'length') may be returned!
        """
        return self._data[position : (position + length)]

    def read_uint8(self):
        """Read one unsigned byte at the cursor and advance past it."""
        result = self._data[self._position]
        self._position += 1
        return result

    def read_uint16(self):
        """Read a little-endian unsigned 16-bit integer and advance past it."""
        result = struct.unpack_from("<H", self._data, self._position)[0]
        self._position += 2
        return result

    def read_uint24(self):
        """Read a little-endian unsigned 24-bit integer and advance past it."""
        low, high = struct.unpack_from("<HB", self._data, self._position)
        self._position += 3
        return low + (high << 16)

    def read_uint32(self):
        """Read a little-endian unsigned 32-bit integer and advance past it."""
        result = struct.unpack_from("<I", self._data, self._position)[0]
        self._position += 4
        return result

    def read_uint64(self):
        """Read a little-endian unsigned 64-bit integer and advance past it."""
        result = struct.unpack_from("<Q", self._data, self._position)[0]
        self._position += 8
        return result

    def read_string(self):
        """Read a NUL-terminated byte string and advance past the terminator.

        Returns None, leaving the cursor unchanged, when no NUL byte remains.
        """
        end_pos = self._data.find(b"\0", self._position)
        if end_pos < 0:
            return None
        result = self._data[self._position : end_pos]
        self._position = end_pos + 1
        return result

    def read_length_encoded_integer(self):
        """Read a 'Length Coded Binary' number from the data buffer.

        Length coded numbers can be anywhere from 1 to 9 bytes depending
        on the value of the first byte. Returns None for the NULL marker
        and for a first byte that is not a valid length prefix (0xFF).
        """
        c = self.read_uint8()
        if c == NULL_COLUMN:
            return None
        if c < UNSIGNED_CHAR_COLUMN:
            return c
        elif c == UNSIGNED_SHORT_COLUMN:
            return self.read_uint16()
        elif c == UNSIGNED_INT24_COLUMN:
            return self.read_uint24()
        elif c == UNSIGNED_INT64_COLUMN:
            return self.read_uint64()
        # Only 0xFF reaches here; make the None result explicit rather than
        # falling off the end of the function.
        return None

    def read_length_coded_string(self):
        """Read a 'Length Coded String' from the data buffer.

        A 'Length Coded String' consists first of a length coded
        (unsigned, positive) integer represented in 1-9 bytes followed by
        that many bytes of binary data. (For example "cat" would be "3cat".)
        Returns None when the length is the NULL marker.
        """
        length = self.read_length_encoded_integer()
        if length is None:
            return None
        return self.read(length)

    def read_struct(self, fmt):
        """Unpack ``fmt`` (a struct format) at the cursor and advance past it."""
        s = struct.Struct(fmt)
        result = s.unpack_from(self._data, self._position)
        self._position += s.size
        return result

    def _first_byte(self):
        """Return the payload's first byte, or None if the payload is empty.

        Lets the ``is_*`` predicates answer False for an empty payload
        instead of raising IndexError.
        """
        return self._data[0] if self._data else None

    def is_ok_packet(self):
        # https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
        return self._first_byte() == 0 and len(self._data) >= 7

    def is_eof_packet(self):
        # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
        # Caution: \xFE may be LengthEncodedInteger.
        # If \xFE is LengthEncodedInteger header, 8bytes followed.
        return self._first_byte() == 0xFE and len(self._data) < 9

    def is_auth_switch_request(self):
        # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        return self._first_byte() == 0xFE

    def is_extra_auth_data(self):
        # https://dev.mysql.com/doc/internals/en/successful-authentication.html
        return self._first_byte() == 1

    def is_resultset_packet(self):
        # The first byte is the column count for a result set header.
        field_count = self._first_byte()
        return field_count is not None and 1 <= field_count <= 250

    def is_load_local_packet(self):
        return self._first_byte() == 0xFB

    def is_error_packet(self):
        return self._first_byte() == 0xFF

    def check_error(self):
        """Raise the corresponding MySQL exception if this is an error packet."""
        if self.is_error_packet():
            self.raise_for_error()

    def raise_for_error(self):
        """Raise the exception built by err.raise_mysql_exception from the payload.

        The error number is read here only so it can be printed when DEBUG
        is enabled; the whole payload is handed to err.raise_mysql_exception.
        """
        self.rewind()
        self.advance(1)  # field_count == error (we already know that)
        errno = self.read_uint16()
        if DEBUG:
            print("errno =", errno)
        err.raise_mysql_exception(self._data)

    def dump(self):
        """Print a hex/ASCII dump of the payload via dump_packet()."""
        dump_packet(self._data)

223 

224 

class FieldDescriptorPacket(MysqlPacket):
    """Column-definition packet describing one column of a result set.

    The payload is parsed on construction. The metadata is exposed as public
    attributes:

    - ``catalog`` and ``db``: raw bytes.
    - ``table_name``, ``org_table``, ``name``, ``org_name``: decoded with the
      given encoding.
    - ``charsetnr``, ``length``, ``type_code``, ``flags``, ``scale``: integers.
    """

    def __init__(self, data, encoding):
        MysqlPacket.__init__(self, data, encoding)
        self._parse_field_descriptor(encoding)

    def _parse_field_descriptor(self, encoding):
        """Parse the 'Field Descriptor' (Metadata) packet.

        This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
        """
        next_string = self.read_length_coded_string
        # The first two length-coded strings are stored undecoded.
        self.catalog = next_string()
        self.db = next_string()
        # The next four are decoded text, read in wire order.
        for attr_name in ("table_name", "org_table", "name", "org_name"):
            setattr(self, attr_name, next_string().decode(encoding))
        # Fixed-size tail: one skipped byte, charset (u16), column length
        # (u32), type (u8), flags (u16), scale (u8), two skipped bytes.
        fixed = self.read_struct("<xHIBHBxx")
        self.charsetnr, self.length, self.type_code, self.flags, self.scale = fixed
        # Any 'default' value after this is left unread; it does not appear
        # to be used for normal result sets.

    def description(self):
        """Provides a 7-item tuple compatible with the Python PEP249 DB Spec."""
        column_length = self.get_column_length()
        # null_ok: True when the lowest flag bit is clear (presumably the
        # NOT_NULL flag — confirm against FIELD_FLAG constants).
        null_ok = not self.flags & 1
        return (
            self.name,
            self.type_code,
            None,  # TODO: display_length; should this be self.length?
            column_length,  # 'internal_size'
            column_length,  # 'precision'  # TODO: why!?!?
            self.scale,
            null_ok,
        )

    def get_column_length(self):
        """Return the column length, scaled by MBLENGTH for VAR_STRING columns."""
        if self.type_code != FIELD_TYPE.VAR_STRING:
            return self.length
        # Divide the byte length by MBLENGTH for this charset (1 when the
        # charset number is unknown) — presumably bytes per character.
        return self.length // MBLENGTH.get(self.charsetnr, 1)

    def __str__(self):
        template = "{} {!r}.{!r}.{!r}, type={}, flags={:x}"
        parts = (
            self.__class__,
            self.db,
            self.table_name,
            self.name,
            self.type_code,
            self.flags,
        )
        return template.format(*parts)

284 

285 

class OKPacketWrapper:
    """
    OK Packet Wrapper. It uses an existing packet object, and wraps
    around it, exposing useful variables while still providing access
    to the original packet objects variables and methods.

    Exposes ``affected_rows``, ``insert_id``, ``server_status``,
    ``warning_count``, ``message`` and ``has_next``.
    Raises ValueError if ``from_packet`` is not an OK packet.
    """

    def __init__(self, from_packet):
        if not from_packet.is_ok_packet():
            raise ValueError(
                "Cannot create "
                + str(self.__class__.__name__)
                + " object from invalid packet type"
            )

        self.packet = from_packet
        # Skip the header byte (assumes the cursor is at the payload start).
        self.packet.advance(1)

        self.affected_rows = self.packet.read_length_encoded_integer()
        self.insert_id = self.packet.read_length_encoded_integer()
        # Read through the wrapped packet explicitly rather than relying on
        # __getattr__ delegation.
        self.server_status, self.warning_count = self.packet.read_struct("<HH")
        self.message = self.packet.read_all()
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        # Only called for names not found on the wrapper; forward them to the
        # wrapped packet. Look ``packet`` up in __dict__ directly: going
        # through ``self.packet`` when it is not set yet (e.g. an instance
        # created via __new__ by copy/pickle) would re-enter __getattr__
        # until RecursionError.
        try:
            packet = self.__dict__["packet"]
        except KeyError:
            raise AttributeError(key) from None
        return getattr(packet, key)

312 

313 

class EOFPacketWrapper:
    """
    EOF Packet Wrapper. It uses an existing packet object, and wraps
    around it, exposing useful variables while still providing access
    to the original packet objects variables and methods.

    Exposes ``warning_count``, ``server_status`` and ``has_next``.
    Raises ValueError if ``from_packet`` is not an EOF packet.
    """

    def __init__(self, from_packet):
        if not from_packet.is_eof_packet():
            raise ValueError(
                f"Cannot create '{self.__class__}' object from invalid packet type"
            )

        self.packet = from_packet
        # Skip the header byte, then read two 2-byte fields. Both are unsigned
        # on the wire; a signed format ("h") would turn values >= 0x8000 into
        # negative numbers.
        self.warning_count, self.server_status = self.packet.read_struct("<xHH")
        if DEBUG:
            print("server_status=", self.server_status)
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        # Only called for names not found on the wrapper; forward them to the
        # wrapped packet. Look ``packet`` up in __dict__ directly: going
        # through ``self.packet`` when it is not set yet (e.g. an instance
        # created via __new__ by copy/pickle) would re-enter __getattr__
        # until RecursionError.
        try:
            packet = self.__dict__["packet"]
        except KeyError:
            raise AttributeError(key) from None
        return getattr(packet, key)

335 

336 

class LoadLocalPacketWrapper:
    """
    Load Local Packet Wrapper. It uses an existing packet object, and wraps
    around it, exposing useful variables while still providing access
    to the original packet objects variables and methods.

    Exposes ``filename``: the payload bytes after the first (marker) byte.
    Raises ValueError if ``from_packet`` is not a LOAD LOCAL packet.
    """

    def __init__(self, from_packet):
        if not from_packet.is_load_local_packet():
            raise ValueError(
                f"Cannot create '{self.__class__}' object from invalid packet type"
            )

        self.packet = from_packet
        self.filename = self.packet.get_all_data()[1:]
        if DEBUG:
            print("filename=", self.filename)

    def __getattr__(self, key):
        # Only called for names not found on the wrapper; forward them to the
        # wrapped packet. Look ``packet`` up in __dict__ directly: going
        # through ``self.packet`` when it is not set yet (e.g. an instance
        # created via __new__ by copy/pickle) would re-enter __getattr__
        # until RecursionError.
        try:
            packet = self.__dict__["packet"]
        except KeyError:
            raise AttributeError(key) from None
        return getattr(packet, key)