Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/pymysql/protocol.py: 31%
174 statements
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:28 +0000
« prev ^ index » next coverage.py v7.3.1, created at 2023-09-25 06:28 +0000
# Python implementation of the low-level MySQL client-server protocol
# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
4from .charset import MBLENGTH
5from .constants import FIELD_TYPE, SERVER_STATUS
6from . import err
8import struct
9import sys
# Set to True to make this module print protocol debugging output
# (packet dumps on short reads, error numbers, server status, filenames).
DEBUG = False

# Markers found in the first byte of a length-encoded integer: a first byte
# below UNSIGNED_CHAR_COLUMN is the integer value itself, NULL_COLUMN marks a
# NULL value, and the other markers announce a 2-, 3- or 8-byte unsigned
# little-endian integer that follows.
NULL_COLUMN = 0xFB  # 251
UNSIGNED_CHAR_COLUMN = 0xFB  # 251: smallest first byte that is not a literal value
UNSIGNED_SHORT_COLUMN = 0xFC  # 252: uint16 follows
UNSIGNED_INT24_COLUMN = 0xFD  # 253: uint24 follows
UNSIGNED_INT64_COLUMN = 0xFE  # 254: uint64 follows
def dump_packet(data):  # pragma: no cover
    """Print a hex/ASCII dump of a packet payload for debugging.

    Prints the payload length and up to six caller frames (function name and
    line number). It then prints at most the first 256 bytes, 16 bytes per
    row: the hex values, then a column with printable ASCII ("." for anything
    else).

    :param data: packet payload; iterating it must yield ints (e.g. ``bytes``).
    """

    def printable(byte):
        # Only 7-bit printable ASCII is shown verbatim.
        if 32 <= byte < 127:
            return chr(byte)
        return "."

    print("packet length:", len(data))
    try:
        for i in range(1, 7):
            f = sys._getframe(i)
            print("call[%d]: %s (line %d)" % (i, f.f_code.co_name, f.f_lineno))
    except ValueError:
        # sys._getframe raises ValueError when the stack is shallower than
        # requested; show the frames that exist and carry on with the dump.
        pass
    print("-" * 66)
    for start in range(0, min(len(data), 256), 16):
        row = data[start : start + 16]
        print(
            " ".join(f"{x:02X}" for x in row)
            # Every hex byte occupies 3 columns ("XX "), so pad short rows by
            # 3 spaces per missing byte to keep the ASCII column aligned.
            + "   " * (16 - len(row))
            + " " * 2
            + "".join(printable(x) for x in row)
        )
    print("-" * 66)
    print()
class MysqlPacket:
    """Representation of a MySQL response packet.

    Provides an interface for reading/parsing the packet results. The packet
    keeps a read cursor (``_position``) into its payload. The ``read_*``
    methods consume bytes at the cursor and advance it; ``get_bytes`` and
    ``get_all_data`` do not move it.
    """

    __slots__ = ("_position", "_data")

    def __init__(self, data, encoding):
        # ``encoding`` is not used here. It is accepted so that subclasses
        # (e.g. FieldDescriptorPacket) can forward their constructor
        # arguments unchanged.
        self._position = 0
        self._data = data

    def get_all_data(self):
        """Return the whole payload, regardless of the cursor position."""
        return self._data

    def read(self, size):
        """Read the first 'size' bytes in packet and advance cursor past them.

        :raises AssertionError: if fewer than ``size`` bytes remain.
        """
        result = self._data[self._position : (self._position + size)]
        if len(result) != size:
            error = (
                "Result length not requested length:\n"
                "Expected=%s. Actual=%s. Position: %s. Data Length: %s"
                % (size, len(result), self._position, len(self._data))
            )
            if DEBUG:
                print(error)
                self.dump()
            raise AssertionError(error)
        self._position += size
        return result

    def read_all(self):
        """Read all remaining data in the packet.

        The cursor is set to None afterwards, so any subsequent cursor-based
        read raises (TypeError) instead of silently returning data.
        """
        result = self._data[self._position :]
        self._position = None  # ensure no subsequent read()
        return result

    def advance(self, length):
        """Advance the cursor in data buffer 'length' bytes.

        ``length`` may be negative.

        :raises Exception: if the new position falls outside ``[0, len(data)]``.
        """
        new_position = self._position + length
        if new_position < 0 or new_position > len(self._data):
            raise Exception(
                "Invalid advance amount (%s) for cursor. "
                "Position=%s" % (length, new_position)
            )
        self._position = new_position

    def rewind(self, position=0):
        """Set the position of the data buffer cursor to 'position'.

        :raises Exception: if ``position`` is outside ``[0, len(data)]``.
        """
        if position < 0 or position > len(self._data):
            raise Exception("Invalid position to rewind cursor to: %s." % position)
        self._position = position

    def get_bytes(self, position, length=1):
        """Get 'length' bytes starting at 'position' without moving the cursor.

        Position is start of payload (first four packet header bytes are not
        included) starting at index '0'.

        No error checking is done. If requesting outside end of buffer
        an empty (or shorter than 'length') bytes object may be returned!
        """
        return self._data[position : (position + length)]

    def read_uint8(self):
        """Read one unsigned byte at the cursor and advance past it."""
        result = self._data[self._position]
        self._position += 1
        return result

    def read_uint16(self):
        """Read a little-endian unsigned 16-bit integer and advance 2 bytes."""
        result = struct.unpack_from("<H", self._data, self._position)[0]
        self._position += 2
        return result

    def read_uint24(self):
        """Read a little-endian unsigned 24-bit integer and advance 3 bytes."""
        low, high = struct.unpack_from("<HB", self._data, self._position)
        self._position += 3
        return low + (high << 16)

    def read_uint32(self):
        """Read a little-endian unsigned 32-bit integer and advance 4 bytes."""
        result = struct.unpack_from("<I", self._data, self._position)[0]
        self._position += 4
        return result

    def read_uint64(self):
        """Read a little-endian unsigned 64-bit integer and advance 8 bytes."""
        result = struct.unpack_from("<Q", self._data, self._position)[0]
        self._position += 8
        return result

    def read_string(self):
        """Read a NUL-terminated byte string at the cursor.

        Returns the bytes before the terminator and moves the cursor past the
        terminator. If no terminator remains, returns None and leaves the
        cursor unchanged.
        """
        end_pos = self._data.find(b"\0", self._position)
        if end_pos < 0:
            return None
        result = self._data[self._position : end_pos]
        self._position = end_pos + 1
        return result

    def read_length_encoded_integer(self):
        """Read a 'Length Coded Binary' number from the data buffer.

        Length coded numbers can be anywhere from 1 to 9 bytes depending
        on the value of the first byte:

        * below 251: the byte itself is the value;
        * 251: NULL, returned as None;
        * 252 / 253 / 254: a 2 / 3 / 8 byte unsigned integer follows.

        Any other first byte (0xFF) matches no branch, and None is returned.
        """
        c = self.read_uint8()
        if c == NULL_COLUMN:
            return None
        if c < UNSIGNED_CHAR_COLUMN:
            return c
        elif c == UNSIGNED_SHORT_COLUMN:
            return self.read_uint16()
        elif c == UNSIGNED_INT24_COLUMN:
            return self.read_uint24()
        elif c == UNSIGNED_INT64_COLUMN:
            return self.read_uint64()
        return None

    def read_length_coded_string(self):
        """Read a 'Length Coded String' from the data buffer.

        A 'Length Coded String' consists first of a length coded
        (unsigned, positive) integer represented in 1-9 bytes followed by
        that many bytes of binary data. (For example "cat" would be "3cat".)
        Returns None when the length prefix encodes NULL.
        """
        length = self.read_length_encoded_integer()
        if length is None:
            return None
        return self.read(length)

    def read_struct(self, fmt):
        """Unpack ``fmt`` (a ``struct`` format) at the cursor and advance past it."""
        s = struct.Struct(fmt)
        result = s.unpack_from(self._data, self._position)
        self._position += s.size
        return result

    def _first_byte(self):
        """Return the first payload byte as an int, or None if the payload is empty."""
        data = self._data
        return data[0] if len(data) else None

    def is_ok_packet(self):
        # https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
        # Length is checked first so an empty payload yields False, not IndexError.
        return len(self._data) >= 7 and self._data[0] == 0

    def is_eof_packet(self):
        # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
        # Caution: \xFE may be LengthEncodedInteger.
        # If \xFE is LengthEncodedInteger header, 8bytes followed.
        return self._first_byte() == 0xFE and len(self._data) < 9

    def is_auth_switch_request(self):
        # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
        return self._first_byte() == 0xFE

    def is_extra_auth_data(self):
        # https://dev.mysql.com/doc/internals/en/successful-authentication.html
        return self._first_byte() == 1

    def is_resultset_packet(self):
        # The first byte is treated as a column count; 1..250 means a result set.
        field_count = self._first_byte()
        return field_count is not None and 1 <= field_count <= 250

    def is_load_local_packet(self):
        return self._first_byte() == 0xFB

    def is_error_packet(self):
        return self._first_byte() == 0xFF

    def check_error(self):
        """Raise the server error if this is an error packet; otherwise do nothing."""
        if self.is_error_packet():
            self.raise_for_error()

    def raise_for_error(self):
        """Hand the error payload to ``err.raise_mysql_exception``.

        The error number is read only for DEBUG output. The full payload is
        passed on, and presumably decoded again inside ``err``.
        """
        self.rewind()
        self.advance(1)  # field_count == error (we already know that)
        errno = self.read_uint16()
        if DEBUG:
            print("errno =", errno)
        err.raise_mysql_exception(self._data)

    def dump(self):
        """Print a hex/ASCII dump of the payload (see ``dump_packet``)."""
        dump_packet(self._data)
class FieldDescriptorPacket(MysqlPacket):
    """Column-metadata packet of a result set, parsed as soon as it is built.

    After construction the column definition is available as public
    attributes: ``catalog`` and ``db`` (left as raw bytes); ``table_name``,
    ``org_table``, ``name`` and ``org_name`` (decoded with ``encoding``); and
    ``charsetnr``, ``length``, ``type_code``, ``flags`` and ``scale``.
    """

    def __init__(self, data, encoding):
        MysqlPacket.__init__(self, data, encoding)
        self._parse_field_descriptor(encoding)

    def _parse_field_descriptor(self, encoding):
        """Decode the column-definition ('Field Descriptor') payload.

        Handles the MySQL 4.1+ layout only; MySQL 4.0 is not supported.
        """
        next_string = self.read_length_coded_string

        # Catalog and schema names are kept undecoded.
        self.catalog = next_string()
        self.db = next_string()

        # The next four strings are decoded, in wire order, with ``encoding``.
        for attr in ("table_name", "org_table", "name", "org_name"):
            setattr(self, attr, next_string().decode(encoding))

        # Fixed-size tail: 1 skipped byte, uint16 charsetnr, uint32 length,
        # uint8 type_code, uint16 flags, uint8 scale, 2 skipped bytes
        # (all little-endian).
        fixed_fields = self.read_struct("<xHIBHBxx")
        (
            self.charsetnr,
            self.length,
            self.type_code,
            self.flags,
            self.scale,
        ) = fixed_fields
        # Any trailing 'default' value (a length-coded binary) is left
        # unread; it is not used for normal result sets.

    def description(self):
        """Return the 7-item column description tuple of PEP 249.

        Order: name, type_code, display_size, internal_size, precision,
        scale, null_ok.
        """
        # Even flag values mean null_ok; presumably bit 0 is the NOT_NULL flag.
        null_ok = self.flags % 2 == 0
        return (
            self.name,
            self.type_code,
            None,  # TODO: display_length; should this be self.length?
            self.get_column_length(),  # 'internal_size'
            self.get_column_length(),  # 'precision' # TODO: why!?!?
            self.scale,
            null_ok,
        )

    def get_column_length(self):
        """Return ``length``, divided by the charset's MBLENGTH for VAR_STRING."""
        if self.type_code != FIELD_TYPE.VAR_STRING:
            return self.length
        # MBLENGTH presumably maps a charset number to its max bytes per
        # character (unknown charsets default to 1); verify in .charset.
        return self.length // MBLENGTH.get(self.charsetnr, 1)

    def __str__(self):
        return (
            f"{self.__class__} {self.db!r}.{self.table_name!r}.{self.name!r}, "
            f"type={self.type_code}, flags={self.flags:x}"
        )
class OKPacketWrapper:
    """
    OK Packet Wrapper. It uses an existing packet object, and wraps
    around it, exposing useful variables while still providing access
    to the original packet objects variables and methods.

    Parsed attributes: ``affected_rows``, ``insert_id``, ``server_status``,
    ``warning_count``, ``message`` and ``has_next``. ``has_next`` is nonzero
    when the SERVER_MORE_RESULTS_EXISTS status bit is set.
    """

    def __init__(self, from_packet):
        if not from_packet.is_ok_packet():
            raise ValueError(
                "Cannot create "
                + str(self.__class__.__name__)
                + " object from invalid packet type"
            )

        self.packet = from_packet
        # Skip the header byte already validated by is_ok_packet().
        self.packet.advance(1)

        self.affected_rows = self.packet.read_length_encoded_integer()
        self.insert_id = self.packet.read_length_encoded_integer()
        self.server_status, self.warning_count = self.packet.read_struct("<HH")
        self.message = self.packet.read_all()
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        # Called only when normal lookup fails. If ``packet`` itself is
        # missing (e.g. an instance rebuilt by copy/pickle via __new__),
        # delegating would re-enter __getattr__("packet") until
        # RecursionError. Raise AttributeError so hasattr()/copy work.
        if key == "packet":
            raise AttributeError(key)
        return getattr(self.packet, key)
class EOFPacketWrapper:
    """
    EOF Packet Wrapper. It uses an existing packet object, and wraps
    around it, exposing useful variables while still providing access
    to the original packet objects variables and methods.

    Parsed attributes: ``warning_count``, ``server_status`` and ``has_next``.
    ``has_next`` is nonzero when the SERVER_MORE_RESULTS_EXISTS status bit is
    set.
    """

    def __init__(self, from_packet):
        if not from_packet.is_eof_packet():
            raise ValueError(
                f"Cannot create '{self.__class__}' object from invalid packet type"
            )

        self.packet = from_packet
        # Skip the 0xFE header byte, then read two little-endian unsigned
        # 16-bit fields. They must be 'H', not 'h': a signed read would turn
        # warning counts >= 32768 negative. Reading starts at the current
        # cursor, which presumably is still 0 for a freshly read packet.
        self.warning_count, self.server_status = self.packet.read_struct("<xHH")
        if DEBUG:
            print("server_status=", self.server_status)
        self.has_next = self.server_status & SERVER_STATUS.SERVER_MORE_RESULTS_EXISTS

    def __getattr__(self, key):
        # Called only when normal lookup fails. If ``packet`` itself is
        # missing (e.g. an instance rebuilt by copy/pickle via __new__),
        # delegating would re-enter __getattr__("packet") until
        # RecursionError. Raise AttributeError so hasattr()/copy work.
        if key == "packet":
            raise AttributeError(key)
        return getattr(self.packet, key)
class LoadLocalPacketWrapper:
    """
    Load Local Packet Wrapper. It uses an existing packet object, and wraps
    around it, exposing useful variables while still providing access
    to the original packet objects variables and methods.

    ``filename`` holds the raw payload bytes that follow the 0xFB header byte.
    """

    def __init__(self, from_packet):
        if not from_packet.is_load_local_packet():
            raise ValueError(
                f"Cannot create '{self.__class__}' object from invalid packet type"
            )

        self.packet = from_packet
        # Everything after the header byte is taken as the filename (undecoded).
        self.filename = self.packet.get_all_data()[1:]
        if DEBUG:
            print("filename=", self.filename)

    def __getattr__(self, key):
        # Called only when normal lookup fails. If ``packet`` itself is
        # missing (e.g. an instance rebuilt by copy/pickle via __new__),
        # delegating would re-enter __getattr__("packet") until
        # RecursionError. Raise AttributeError so hasattr()/copy work.
        if key == "packet":
            raise AttributeError(key)
        return getattr(self.packet, key)