1# Python implementation of the MySQL client-server protocol
2# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
3# Error codes:
4# https://dev.mysql.com/doc/refman/5.5/en/error-handling.html
5import errno
6import os
7import socket
8import struct
9import sys
10import traceback
11import warnings
12
13from . import _auth
14
15from .charset import charset_by_name, charset_by_id
16from .constants import CLIENT, COMMAND, CR, ER, FIELD_TYPE, SERVER_STATUS
17from . import converters
18from .cursors import Cursor
19from .optionfile import Parser
20from .protocol import (
21 dump_packet,
22 MysqlPacket,
23 FieldDescriptorPacket,
24 OKPacketWrapper,
25 EOFPacketWrapper,
26 LoadLocalPacketWrapper,
27)
28from . import err, VERSION_STRING
29
# TLS support is optional: when the ssl module cannot be imported the name
# ``ssl`` is bound to None and SSL_ENABLED records that TLS is unavailable.
try:
    import ssl

    SSL_ENABLED = True
except ImportError:
    ssl = None
    SSL_ENABLED = False

# Default login name: the current OS user, or None when it cannot be
# determined. The getpass module name is removed again after use.
try:
    import getpass

    DEFAULT_USER = getpass.getuser()
    del getpass
except (ImportError, KeyError, OSError):
    # When there's no entry in OS database for a current user:
    # KeyError is raised in Python 3.12 and below.
    # OSError is raised in Python 3.13+
    DEFAULT_USER = None

# When True, packets are dumped and connection progress is printed.
DEBUG = False
_DEFAULT_AUTH_PLUGIN = None  # if this is not None, use it instead of server's default.
51
# Set of FIELD_TYPE codes for string-like/binary column types.
# NOTE(review): presumably consulted when decoding result columns; the code
# that reads this set is outside the visible part of the file -- confirm.
TEXT_TYPES = {
    FIELD_TYPE.BIT,
    FIELD_TYPE.BLOB,
    FIELD_TYPE.LONG_BLOB,
    FIELD_TYPE.MEDIUM_BLOB,
    FIELD_TYPE.STRING,
    FIELD_TYPE.TINY_BLOB,
    FIELD_TYPE.VAR_STRING,
    FIELD_TYPE.VARCHAR,
    FIELD_TYPE.GEOMETRY,
}


# Charset used when the caller (and the option file) does not name one.
DEFAULT_CHARSET = "utf8mb4"

# Largest payload that fits in one packet: the packet header stores the
# payload length in 3 bytes. A payload of exactly this size means that
# another packet follows.
MAX_PACKET_LEN = 2**24 - 1
68
69
def _pack_int24(n):
    """Encode ``n`` as a 3-byte little-endian integer.

    The value is packed as a 4-byte unsigned little-endian integer and the
    most significant byte is dropped.
    """
    four_bytes = struct.pack("<I", n)
    return four_bytes[:3]
72
73
74# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
def _lenenc_int(i):
    """Encode a non-negative integer as a MySQL length-encoded integer.

    Values below 0xFB are a single byte; larger values are a marker byte
    (0xFC, 0xFD or 0xFE) followed by a 2-, 3- or 8-byte little-endian value.

    :raise ValueError: If ``i`` is negative or does not fit in 64 bits.
    """
    if i < 0:
        raise ValueError(
            "Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i
        )
    if i < 0xFB:
        return bytes((i,))

    # (marker byte, exclusive upper bound, struct format, bytes kept)
    encodings = (
        (b"\xfc", 1 << 16, "<H", 2),
        (b"\xfd", 1 << 24, "<I", 3),
        (b"\xfe", 1 << 64, "<Q", 8),
    )
    for marker, upper_bound, fmt, width in encodings:
        if i < upper_bound:
            return marker + struct.pack(fmt, i)[:width]

    raise ValueError(
        f"Encoding {i:x} is larger than {1 << 64:x} - no representation in LengthEncodedInteger"
    )
92
93
94class Connection:
95 """
96 Representation of a socket with a mysql server.
97
98 The proper way to get an instance of this class is to call
99 connect().
100
101 Establish a connection to the MySQL database. Accepts several
102 arguments:
103
104 :param host: Host where the database server is located.
105 :param user: Username to log in as.
106 :param password: Password to use.
107 :param database: Database to use, None to not use a particular one.
108 :param port: MySQL port to use, default is usually OK. (default: 3306)
109 :param bind_address: When the client has multiple network interfaces, specify
110 the interface from which to connect to the host. Argument can be
111 a hostname or an IP address.
112 :param unix_socket: Use a unix socket rather than TCP/IP.
113 :param read_timeout: The timeout for reading from the connection in seconds.
114 (default: None - no timeout)
115 :param write_timeout: The timeout for writing to the connection in seconds.
116 (default: None - no timeout)
117 :param str charset: Charset to use.
118 :param str collation: Collation name to use.
119 :param sql_mode: Default SQL_MODE to use.
120 :param read_default_file:
121 Specifies my.cnf file to read these parameters from under the [client] section.
122 :param conv:
123 Conversion dictionary to use instead of the default one.
124 This is used to provide custom marshalling and unmarshalling of types.
125 See converters.
126 :param use_unicode:
127 Whether or not to default to unicode strings.
128 This option defaults to true.
129 :param client_flag: Custom flags to send to MySQL. Find potential values in constants.CLIENT.
130 :param cursorclass: Custom cursor class to use.
131 :param init_command: Initial SQL statement to run when connection is established.
132 :param connect_timeout: The timeout for connecting to the database in seconds.
133 (default: 10, min: 1, max: 31536000)
134 :param ssl: A dict of arguments similar to mysql_ssl_set()'s parameters or an ssl.SSLContext.
135 :param ssl_ca: Path to the file that contains a PEM-formatted CA certificate.
136 :param ssl_cert: Path to the file that contains a PEM-formatted client certificate.
137 :param ssl_disabled: A boolean value that disables usage of TLS.
138 :param ssl_key: Path to the file that contains a PEM-formatted private key for
139 the client certificate.
140 :param ssl_key_password: The password for the client certificate private key.
141 :param ssl_verify_cert: Set to true to check the server certificate's validity.
142 :param ssl_verify_identity: Set to true to check the server's identity.
143 :param read_default_group: Group to read from in the configuration file.
144 :param autocommit: Autocommit mode. None means use server default. (default: False)
145 :param local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
    :param max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
        Only used to limit size of "LOAD LOCAL INFILE" data packet smaller than default (16MB).
148 :param defer_connect: Don't explicitly connect on construction - wait for connect call.
149 (default: False)
150 :param auth_plugin_map: A dict of plugin names to a class that processes that plugin.
151 The class will take the Connection object as the argument to the constructor.
152 The class needs an authenticate method taking an authentication packet as
153 an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
154 (if no authenticate method) for returning a string from the user. (experimental)
155 :param server_public_key: SHA256 authentication plugin public key value. (default: None)
156 :param binary_prefix: Add _binary prefix on bytes and bytearray. (default: False)
157 :param compress: Not supported.
158 :param named_pipe: Not supported.
159 :param db: **DEPRECATED** Alias for database.
160 :param passwd: **DEPRECATED** Alias for password.
161
162 See `Connection <https://www.python.org/dev/peps/pep-0249/#connection-objects>`_ in the
163 specification.
164 """
165
    # Class-level defaults so these attributes exist before connect() runs.

    # Connected socket object, or None while not connected.
    _sock = None
    # Binary reader created from the socket with makefile("rb") in connect().
    _rfile = None
    # Name of the auth plugin used to build the handshake response.
    # NOTE(review): presumably filled in from the server handshake -- the code
    # that assigns it is outside the visible part of the file; confirm.
    _auth_plugin_name = ""
    # Set to True by close(); reset to False at the start of connect().
    _closed = False
    # Set to True when connected over a UNIX socket or after the TLS upgrade.
    _secure = False
171
    def __init__(
        self,
        *,
        user=None,  # The first four arguments are based on DB-API 2.0 recommendation.
        password="",
        host=None,
        database=None,
        unix_socket=None,
        port=0,
        charset="",
        collation=None,
        sql_mode=None,
        read_default_file=None,
        conv=None,
        use_unicode=True,
        client_flag=0,
        cursorclass=Cursor,
        init_command=None,
        connect_timeout=10,
        read_default_group=None,
        autocommit=False,
        local_infile=False,
        max_allowed_packet=16 * 1024 * 1024,
        defer_connect=False,
        auth_plugin_map=None,
        read_timeout=None,
        write_timeout=None,
        bind_address=None,
        binary_prefix=False,
        program_name=None,
        server_public_key=None,
        ssl=None,
        ssl_ca=None,
        ssl_cert=None,
        ssl_disabled=None,
        ssl_key=None,
        ssl_key_password=None,
        ssl_verify_cert=None,
        ssl_verify_identity=None,
        compress=None,  # not supported
        named_pipe=None,  # not supported
        passwd=None,  # deprecated
        db=None,  # deprecated
    ):
        """Validate and store the connection settings, then connect.

        All arguments are keyword-only. Settings missing from the arguments
        may be filled in from an option file. Unless ``defer_connect`` is
        true, :meth:`connect` is called before returning.

        :raise NotImplementedError: If ``compress`` or ``named_pipe`` is set,
            or TLS is requested but the ssl module is unavailable.
        :raise ValueError: If ``port`` is not an int or a timeout is out of range.
        """
        # Deprecated aliases only apply when the preferred argument is unset.
        if db is not None and database is None:
            # We will raise warning in 2022 or later.
            # See https://github.com/PyMySQL/PyMySQL/issues/939
            # warnings.warn("'db' is deprecated, use 'database'", DeprecationWarning, 3)
            database = db
        if passwd is not None and not password:
            # We will raise warning in 2022 or later.
            # See https://github.com/PyMySQL/PyMySQL/issues/939
            # warnings.warn(
            #    "'passwd' is deprecated, use 'password'", DeprecationWarning, 3
            # )
            password = passwd

        if compress or named_pipe:
            raise NotImplementedError(
                "compress and named_pipe arguments are not supported"
            )

        self._local_infile = bool(local_infile)
        if self._local_infile:
            client_flag |= CLIENT.LOCAL_FILES

        # A group without a file implies the platform's default option file.
        if read_default_group and not read_default_file:
            if sys.platform.startswith("win"):
                read_default_file = "c:\\my.ini"
            else:
                read_default_file = "/etc/my.cnf"

        if read_default_file:
            if not read_default_group:
                read_default_group = "client"

            cfg = Parser()
            cfg.read(os.path.expanduser(read_default_file))

            def _config(key, arg):
                # An explicit (truthy) argument wins; otherwise use the option
                # file value, falling back to the argument if the lookup fails.
                if arg:
                    return arg
                try:
                    return cfg.get(read_default_group, key)
                except Exception:
                    return arg

            user = _config("user", user)
            password = _config("password", password)
            host = _config("host", host)
            database = _config("database", database)
            unix_socket = _config("socket", unix_socket)
            port = int(_config("port", port))
            bind_address = _config("bind-address", bind_address)
            charset = _config("default-character-set", charset)
            if not ssl:
                ssl = {}
            # Only a dict can be completed from the option file; any other
            # truthy ``ssl`` value is left untouched.
            if isinstance(ssl, dict):
                for key in ["ca", "capath", "cert", "key", "password", "cipher"]:
                    value = _config("ssl-" + key, ssl.get(key))
                    if value:
                        ssl[key] = value

        self.ssl = False
        if not ssl_disabled:
            # Any of the individual ssl_* arguments replaces the ``ssl``
            # value entirely with a freshly built dict.
            if ssl_ca or ssl_cert or ssl_key or ssl_verify_cert or ssl_verify_identity:
                ssl = {
                    "ca": ssl_ca,
                    "check_hostname": bool(ssl_verify_identity),
                    "verify_mode": ssl_verify_cert
                    if ssl_verify_cert is not None
                    else False,
                }
                if ssl_cert is not None:
                    ssl["cert"] = ssl_cert
                if ssl_key is not None:
                    ssl["key"] = ssl_key
                if ssl_key_password is not None:
                    ssl["password"] = ssl_key_password
            if ssl:
                if not SSL_ENABLED:
                    raise NotImplementedError("ssl module not found")
                self.ssl = True
                client_flag |= CLIENT.SSL
                # self.ctx only exists when TLS was requested.
                self.ctx = self._create_ssl_ctx(ssl)

        self.host = host or "localhost"
        self.port = port or 3306
        if type(self.port) is not int:
            raise ValueError("port should be of type int")
        self.user = user or DEFAULT_USER
        # The password is kept as bytes; str passwords are encoded as latin1.
        self.password = password or b""
        if isinstance(self.password, str):
            self.password = self.password.encode("latin1")
        self.db = database
        self.unix_socket = unix_socket
        self.bind_address = bind_address
        if not (0 < connect_timeout <= 31536000):
            raise ValueError("connect_timeout should be >0 and <=31536000")
        self.connect_timeout = connect_timeout or None
        if read_timeout is not None and read_timeout <= 0:
            raise ValueError("read_timeout should be > 0")
        self._read_timeout = read_timeout
        if write_timeout is not None and write_timeout <= 0:
            raise ValueError("write_timeout should be > 0")
        self._write_timeout = write_timeout

        self.charset = charset or DEFAULT_CHARSET
        self.collation = collation
        self.use_unicode = use_unicode

        # Python codec name matching the MySQL charset.
        self.encoding = charset_by_name(self.charset).encoding

        client_flag |= CLIENT.CAPABILITIES
        if self.db:
            client_flag |= CLIENT.CONNECT_WITH_DB

        self.client_flag = client_flag

        self.cursorclass = cursorclass

        self._result = None
        self._affected_rows = 0
        self.host_info = "Not connected"

        # specified autocommit mode. None means use server default.
        self.autocommit_mode = autocommit

        if conv is None:
            conv = converters.conversions

        # Need for MySQLdb compatibility.
        # Entries with int keys go to decoders; all other keys go to encoders.
        self.encoders = {k: v for (k, v) in conv.items() if type(k) is not int}
        self.decoders = {k: v for (k, v) in conv.items() if type(k) is int}
        self.sql_mode = sql_mode
        self.init_command = init_command
        self.max_allowed_packet = max_allowed_packet
        self._auth_plugin_map = auth_plugin_map or {}
        self._binary_prefix = binary_prefix
        self.server_public_key = server_public_key

        # Connection attributes sent to servers that advertise CONNECT_ATTRS.
        self._connect_attrs = {
            "_client_name": "pymysql",
            "_client_version": VERSION_STRING,
            "_pid": str(os.getpid()),
        }

        if program_name:
            self._connect_attrs["program_name"] = program_name

        if defer_connect:
            self._sock = None
        else:
            self.connect()
366
367 def __enter__(self):
368 return self
369
370 def __exit__(self, *exc_info):
371 del exc_info
372 self.close()
373
374 def _create_ssl_ctx(self, sslp):
375 if isinstance(sslp, ssl.SSLContext):
376 return sslp
377 ca = sslp.get("ca")
378 capath = sslp.get("capath")
379 hasnoca = ca is None and capath is None
380 ctx = ssl.create_default_context(cafile=ca, capath=capath)
381
382 # Python 3.13 enables VERIFY_X509_STRICT by default.
383 # But self signed certificates that are generated by MySQL automatically
384 # doesn't pass the verification.
385 ctx.verify_flags &= ~ssl.VERIFY_X509_STRICT
386
387 ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True)
388 verify_mode_value = sslp.get("verify_mode")
389 if verify_mode_value is None:
390 ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
391 elif isinstance(verify_mode_value, bool):
392 ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE
393 else:
394 if isinstance(verify_mode_value, str):
395 verify_mode_value = verify_mode_value.lower()
396 if verify_mode_value in ("none", "0", "false", "no"):
397 ctx.verify_mode = ssl.CERT_NONE
398 elif verify_mode_value == "optional":
399 ctx.verify_mode = ssl.CERT_OPTIONAL
400 elif verify_mode_value in ("required", "1", "true", "yes"):
401 ctx.verify_mode = ssl.CERT_REQUIRED
402 else:
403 ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
404 if "cert" in sslp:
405 ctx.load_cert_chain(
406 sslp["cert"], keyfile=sslp.get("key"), password=sslp.get("password")
407 )
408 if "cipher" in sslp:
409 ctx.set_ciphers(sslp["cipher"])
410 ctx.options |= ssl.OP_NO_SSLv2
411 ctx.options |= ssl.OP_NO_SSLv3
412 return ctx
413
414 def close(self):
415 """
416 Send the quit message and close the socket.
417
418 See `Connection.close() <https://www.python.org/dev/peps/pep-0249/#Connection.close>`_
419 in the specification.
420
421 :raise Error: If the connection is already closed.
422 """
423 if self._closed:
424 raise err.Error("Already closed")
425 self._closed = True
426 if self._sock is None:
427 return
428 send_data = struct.pack("<iB", 1, COMMAND.COM_QUIT)
429 try:
430 self._write_bytes(send_data)
431 except Exception:
432 pass
433 finally:
434 self._force_close()
435
436 @property
437 def open(self):
438 """Return True if the connection is open."""
439 return self._sock is not None
440
441 def _force_close(self):
442 """Close connection without QUIT message."""
443 if self._rfile:
444 self._rfile.close()
445 if self._sock:
446 try:
447 self._sock.close()
448 except: # noqa
449 pass
450 self._sock = None
451 self._rfile = None
452
453 __del__ = _force_close
454
455 def autocommit(self, value):
456 self.autocommit_mode = bool(value)
457 current = self.get_autocommit()
458 if value != current:
459 self._send_autocommit_mode()
460
461 def get_autocommit(self):
462 return bool(self.server_status & SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT)
463
464 def _read_ok_packet(self):
465 pkt = self._read_packet()
466 if not pkt.is_ok_packet():
467 raise err.OperationalError(
468 CR.CR_COMMANDS_OUT_OF_SYNC,
469 "Command Out of Sync",
470 )
471 ok = OKPacketWrapper(pkt)
472 self.server_status = ok.server_status
473 return ok
474
475 def _send_autocommit_mode(self):
476 """Set whether or not to commit after every execute()."""
477 self._execute_command(
478 COMMAND.COM_QUERY, "SET AUTOCOMMIT = %s" % self.escape(self.autocommit_mode)
479 )
480 self._read_ok_packet()
481
482 def begin(self):
483 """Begin transaction."""
484 self._execute_command(COMMAND.COM_QUERY, "BEGIN")
485 self._read_ok_packet()
486
487 def commit(self):
488 """
489 Commit changes to stable storage.
490
491 See `Connection.commit() <https://www.python.org/dev/peps/pep-0249/#commit>`_
492 in the specification.
493 """
494 self._execute_command(COMMAND.COM_QUERY, "COMMIT")
495 self._read_ok_packet()
496
497 def rollback(self):
498 """
499 Roll back the current transaction.
500
501 See `Connection.rollback() <https://www.python.org/dev/peps/pep-0249/#rollback>`_
502 in the specification.
503 """
504 self._execute_command(COMMAND.COM_QUERY, "ROLLBACK")
505 self._read_ok_packet()
506
507 def show_warnings(self):
508 """Send the "SHOW WARNINGS" SQL command."""
509 self._execute_command(COMMAND.COM_QUERY, "SHOW WARNINGS")
510 result = MySQLResult(self)
511 result.read()
512 return result.rows
513
514 def select_db(self, db):
515 """
516 Set current db.
517
518 :param db: The name of the db.
519 """
520 self._execute_command(COMMAND.COM_INIT_DB, db)
521 self._read_ok_packet()
522
523 def escape(self, obj, mapping=None):
524 """Escape whatever value is passed.
525
526 Non-standard, for internal use; do not use this in your applications.
527 """
528 if isinstance(obj, str):
529 return "'" + self.escape_string(obj) + "'"
530 if isinstance(obj, (bytes, bytearray)):
531 ret = self._quote_bytes(obj)
532 if self._binary_prefix:
533 ret = "_binary" + ret
534 return ret
535 return converters.escape_item(obj, self.charset, mapping=mapping)
536
537 def literal(self, obj):
538 """Alias for escape().
539
540 Non-standard, for internal use; do not use this in your applications.
541 """
542 return self.escape(obj, self.encoders)
543
544 def escape_string(self, s):
545 if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES:
546 return s.replace("'", "''")
547 return converters.escape_string(s)
548
549 def _quote_bytes(self, s):
550 if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES:
551 return "'{}'".format(
552 s.replace(b"'", b"''").decode("ascii", "surrogateescape")
553 )
554 return converters.escape_bytes(s)
555
556 def cursor(self, cursor=None):
557 """
558 Create a new cursor to execute queries with.
559
560 :param cursor: The type of cursor to create. None means use Cursor.
561 :type cursor: :py:class:`Cursor`, :py:class:`SSCursor`, :py:class:`DictCursor`,
562 or :py:class:`SSDictCursor`.
563 """
564 if cursor:
565 return cursor(self)
566 return self.cursorclass(self)
567
568 # The following methods are INTERNAL USE ONLY (called from Cursor)
569 def query(self, sql, unbuffered=False):
570 # if DEBUG:
571 # print("DEBUG: sending query:", sql)
572 if isinstance(sql, str):
573 sql = sql.encode(self.encoding, "surrogateescape")
574 self._execute_command(COMMAND.COM_QUERY, sql)
575 self._affected_rows = self._read_query_result(unbuffered=unbuffered)
576 return self._affected_rows
577
578 def next_result(self, unbuffered=False):
579 self._affected_rows = self._read_query_result(unbuffered=unbuffered)
580 return self._affected_rows
581
582 def affected_rows(self):
583 return self._affected_rows
584
585 def kill(self, thread_id):
586 if not isinstance(thread_id, int):
587 raise TypeError("thread_id must be an integer")
588 self.query(f"KILL {thread_id:d}")
589
590 def ping(self, reconnect=True):
591 """
592 Check if the server is alive.
593
594 :param reconnect: If the connection is closed, reconnect.
595 :type reconnect: boolean
596
597 :raise Error: If the connection is closed and reconnect=False.
598 """
599 if self._sock is None:
600 if reconnect:
601 self.connect()
602 reconnect = False
603 else:
604 raise err.Error("Already closed")
605 try:
606 self._execute_command(COMMAND.COM_PING, "")
607 self._read_ok_packet()
608 except Exception:
609 if reconnect:
610 self.connect()
611 self.ping(False)
612 else:
613 raise
614
615 def set_charset(self, charset):
616 """Deprecated. Use set_character_set() instead."""
617 # This function has been implemented in old PyMySQL.
618 # But this name is different from MySQLdb.
619 # So we keep this function for compatibility and add
620 # new set_character_set() function.
621 self.set_character_set(charset)
622
623 def set_character_set(self, charset, collation=None):
624 """
625 Set charaset (and collation)
626
627 Send "SET NAMES charset [COLLATE collation]" query.
628 Update Connection.encoding based on charset.
629 """
630 # Make sure charset is supported.
631 encoding = charset_by_name(charset).encoding
632
633 if collation:
634 query = f"SET NAMES {charset} COLLATE {collation}"
635 else:
636 query = f"SET NAMES {charset}"
637 self._execute_command(COMMAND.COM_QUERY, query)
638 self._read_packet()
639 self.charset = charset
640 self.encoding = encoding
641 self.collation = collation
642
643 def connect(self, sock=None):
644 self._closed = False
645 try:
646 if sock is None:
647 if self.unix_socket:
648 sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
649 sock.settimeout(self.connect_timeout)
650 sock.connect(self.unix_socket)
651 self.host_info = "Localhost via UNIX socket"
652 self._secure = True
653 if DEBUG:
654 print("connected using unix_socket")
655 else:
656 kwargs = {}
657 if self.bind_address is not None:
658 kwargs["source_address"] = (self.bind_address, 0)
659 while True:
660 try:
661 sock = socket.create_connection(
662 (self.host, self.port), self.connect_timeout, **kwargs
663 )
664 break
665 except OSError as e:
666 if e.errno == errno.EINTR:
667 continue
668 raise
669 self.host_info = "socket %s:%d" % (self.host, self.port)
670 if DEBUG:
671 print("connected using socket")
672 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
673 sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
674 sock.settimeout(None)
675
676 self._sock = sock
677 self._rfile = sock.makefile("rb")
678 self._next_seq_id = 0
679
680 self._get_server_information()
681 self._request_authentication()
682
683 # Send "SET NAMES" query on init for:
684 # - Ensure charaset (and collation) is set to the server.
685 # - collation_id in handshake packet may be ignored.
686 # - If collation is not specified, we don't know what is server's
687 # default collation for the charset. For example, default collation
688 # of utf8mb4 is:
689 # - MySQL 5.7, MariaDB 10.x: utf8mb4_general_ci
690 # - MySQL 8.0: utf8mb4_0900_ai_ci
691 #
692 # Reference:
693 # - https://github.com/PyMySQL/PyMySQL/issues/1092
694 # - https://github.com/wagtail/wagtail/issues/9477
695 # - https://zenn.dev/methane/articles/2023-mysql-collation (Japanese)
696 self.set_character_set(self.charset, self.collation)
697
698 if self.sql_mode is not None:
699 c = self.cursor()
700 c.execute("SET sql_mode=%s", (self.sql_mode,))
701 c.close()
702
703 if self.init_command is not None:
704 c = self.cursor()
705 c.execute(self.init_command)
706 c.close()
707
708 if self.autocommit_mode is not None:
709 self.autocommit(self.autocommit_mode)
710 except BaseException as e:
711 self._force_close()
712
713 if isinstance(e, (OSError, IOError)):
714 exc = err.OperationalError(
715 CR.CR_CONN_HOST_ERROR,
716 f"Can't connect to MySQL server on {self.host!r} ({e})",
717 )
718 # Keep original exception and traceback to investigate error.
719 exc.original_exception = e
720 exc.traceback = traceback.format_exc()
721 if DEBUG:
722 print(exc.traceback)
723 raise exc
724
725 # If e is neither DatabaseError or IOError, It's a bug.
726 # But raising AssertionError hides original error.
727 # So just reraise it.
728 raise
729
730 def write_packet(self, payload):
731 """Writes an entire "mysql packet" in its entirety to the network
732 adding its length and sequence number.
733 """
734 # Internal note: when you build packet manually and calls _write_bytes()
735 # directly, you should set self._next_seq_id properly.
736 data = _pack_int24(len(payload)) + bytes([self._next_seq_id]) + payload
737 if DEBUG:
738 dump_packet(data)
739 self._write_bytes(data)
740 self._next_seq_id = (self._next_seq_id + 1) % 256
741
742 def _read_packet(self, packet_type=MysqlPacket):
743 """Read an entire "mysql packet" in its entirety from the network
744 and return a MysqlPacket type that represents the results.
745
746 :raise OperationalError: If the connection to the MySQL server is lost.
747 :raise InternalError: If the packet sequence number is wrong.
748 """
749 buff = bytearray()
750 while True:
751 packet_header = self._read_bytes(4)
752 # if DEBUG: dump_packet(packet_header)
753
754 btrl, btrh, packet_number = struct.unpack("<HBB", packet_header)
755 bytes_to_read = btrl + (btrh << 16)
756 if packet_number != self._next_seq_id:
757 self._force_close()
758 if packet_number == 0:
759 # MariaDB sends error packet with seqno==0 when shutdown
760 raise err.OperationalError(
761 CR.CR_SERVER_LOST,
762 "Lost connection to MySQL server during query",
763 )
764 raise err.InternalError(
765 "Packet sequence number wrong - got %d expected %d"
766 % (packet_number, self._next_seq_id)
767 )
768 self._next_seq_id = (self._next_seq_id + 1) % 256
769
770 recv_data = self._read_bytes(bytes_to_read)
771 if DEBUG:
772 dump_packet(recv_data)
773 buff += recv_data
774 # https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
775 if bytes_to_read < MAX_PACKET_LEN:
776 break
777
778 packet = packet_type(bytes(buff), self.encoding)
779 if packet.is_error_packet():
780 if self._result is not None and self._result.unbuffered_active is True:
781 self._result.unbuffered_active = False
782 packet.raise_for_error()
783 return packet
784
785 def _read_bytes(self, num_bytes):
786 self._sock.settimeout(self._read_timeout)
787 while True:
788 try:
789 data = self._rfile.read(num_bytes)
790 break
791 except OSError as e:
792 if e.errno == errno.EINTR:
793 continue
794 self._force_close()
795 raise err.OperationalError(
796 CR.CR_SERVER_LOST,
797 f"Lost connection to MySQL server during query ({e})",
798 )
799 except BaseException:
800 # Don't convert unknown exception to MySQLError.
801 self._force_close()
802 raise
803 if len(data) < num_bytes:
804 self._force_close()
805 raise err.OperationalError(
806 CR.CR_SERVER_LOST, "Lost connection to MySQL server during query"
807 )
808 return data
809
810 def _write_bytes(self, data):
811 self._sock.settimeout(self._write_timeout)
812 try:
813 self._sock.sendall(data)
814 except OSError as e:
815 self._force_close()
816 raise err.OperationalError(
817 CR.CR_SERVER_GONE_ERROR, f"MySQL server has gone away ({e!r})"
818 )
819
820 def _read_query_result(self, unbuffered=False):
821 self._result = None
822 result = MySQLResult(self)
823 if unbuffered:
824 result.init_unbuffered_query()
825 else:
826 result.read()
827 self._result = result
828 if result.server_status is not None:
829 self.server_status = result.server_status
830 return result.affected_rows
831
832 def insert_id(self):
833 if self._result:
834 return self._result.insert_id
835 else:
836 return 0
837
838 def _execute_command(self, command, sql):
839 """
840 :raise InterfaceError: If the connection is closed.
841 :raise ValueError: If no username was specified.
842 """
843 if not self._sock:
844 raise err.InterfaceError(0, "")
845
846 # If the last query was unbuffered, make sure it finishes before
847 # sending new commands
848 if self._result is not None:
849 if self._result.unbuffered_active:
850 warnings.warn("Previous unbuffered result was left incomplete")
851 self._result._finish_unbuffered_query()
852 while self._result.has_next:
853 self.next_result()
854 self._result = None
855
856 if isinstance(sql, str):
857 sql = sql.encode(self.encoding)
858
859 packet_size = min(MAX_PACKET_LEN, len(sql) + 1) # +1 is for command
860
861 # tiny optimization: build first packet manually instead of
862 # calling self..write_packet()
863 prelude = struct.pack("<iB", packet_size, command)
864 packet = prelude + sql[: packet_size - 1]
865 self._write_bytes(packet)
866 if DEBUG:
867 dump_packet(packet)
868 self._next_seq_id = 1
869
870 if packet_size < MAX_PACKET_LEN:
871 return
872
873 sql = sql[packet_size - 1 :]
874 while True:
875 packet_size = min(MAX_PACKET_LEN, len(sql))
876 self.write_packet(sql[:packet_size])
877 sql = sql[packet_size:]
878 if not sql and packet_size < MAX_PACKET_LEN:
879 break
880
881 def _request_authentication(self):
882 # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
883 if int(self.server_version.split(".", 1)[0]) >= 5:
884 self.client_flag |= CLIENT.MULTI_RESULTS
885
886 if self.user is None:
887 raise ValueError("Did not specify a username")
888
889 charset_id = charset_by_name(self.charset).id
890 if isinstance(self.user, str):
891 self.user = self.user.encode(self.encoding)
892
893 data_init = struct.pack(
894 "<iIB23s", self.client_flag, MAX_PACKET_LEN, charset_id, b""
895 )
896
897 if self.ssl and self.server_capabilities & CLIENT.SSL:
898 self.write_packet(data_init)
899
900 self._sock = self.ctx.wrap_socket(self._sock, server_hostname=self.host)
901 self._rfile = self._sock.makefile("rb")
902 self._secure = True
903
904 data = data_init + self.user + b"\0"
905
906 authresp = b""
907 plugin_name = None
908
909 if self._auth_plugin_name == "":
910 plugin_name = b""
911 authresp = _auth.scramble_native_password(self.password, self.salt)
912 elif self._auth_plugin_name == "mysql_native_password":
913 plugin_name = b"mysql_native_password"
914 authresp = _auth.scramble_native_password(self.password, self.salt)
915 elif self._auth_plugin_name == "caching_sha2_password":
916 plugin_name = b"caching_sha2_password"
917 if self.password:
918 if DEBUG:
919 print("caching_sha2: trying fast path")
920 authresp = _auth.scramble_caching_sha2(self.password, self.salt)
921 else:
922 if DEBUG:
923 print("caching_sha2: empty password")
924 elif self._auth_plugin_name == "sha256_password":
925 plugin_name = b"sha256_password"
926 if self.ssl and self.server_capabilities & CLIENT.SSL:
927 authresp = self.password + b"\0"
928 elif self.password:
929 authresp = b"\1" # request public key
930 else:
931 authresp = b"\0" # empty password
932
933 if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
934 data += _lenenc_int(len(authresp)) + authresp
935 elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
936 data += struct.pack("B", len(authresp)) + authresp
937 else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
938 data += authresp + b"\0"
939
940 if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
941 if isinstance(self.db, str):
942 self.db = self.db.encode(self.encoding)
943 data += self.db + b"\0"
944
945 if self.server_capabilities & CLIENT.PLUGIN_AUTH:
946 data += (plugin_name or b"") + b"\0"
947
948 if self.server_capabilities & CLIENT.CONNECT_ATTRS:
949 connect_attrs = b""
950 for k, v in self._connect_attrs.items():
951 k = k.encode("utf-8")
952 connect_attrs += _lenenc_int(len(k)) + k
953 v = v.encode("utf-8")
954 connect_attrs += _lenenc_int(len(v)) + v
955 data += _lenenc_int(len(connect_attrs)) + connect_attrs
956
957 self.write_packet(data)
958 auth_packet = self._read_packet()
959
960 # if authentication method isn't accepted the first byte
961 # will have the octet 254
962 if auth_packet.is_auth_switch_request():
963 if DEBUG:
964 print("received auth switch")
965 # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
966 auth_packet.read_uint8() # 0xfe packet identifier
967 plugin_name = auth_packet.read_string()
968 if (
969 self.server_capabilities & CLIENT.PLUGIN_AUTH
970 and plugin_name is not None
971 ):
972 auth_packet = self._process_auth(plugin_name, auth_packet)
973 else:
974 raise err.OperationalError("received unknown auth switch request")
975 elif auth_packet.is_extra_auth_data():
976 if DEBUG:
977 print("received extra data")
978 # https://dev.mysql.com/doc/internals/en/successful-authentication.html
979 if self._auth_plugin_name == "caching_sha2_password":
980 auth_packet = _auth.caching_sha2_password_auth(self, auth_packet)
981 elif self._auth_plugin_name == "sha256_password":
982 auth_packet = _auth.sha256_password_auth(self, auth_packet)
983 else:
984 raise err.OperationalError(
985 "Received extra packet for auth method %r", self._auth_plugin_name
986 )
987
988 if DEBUG:
989 print("Succeed to auth")
990
991 def _process_auth(self, plugin_name, auth_packet):
992 handler = self._get_auth_plugin_handler(plugin_name)
993 if handler:
994 try:
995 return handler.authenticate(auth_packet)
996 except AttributeError:
997 if plugin_name != b"dialog":
998 raise err.OperationalError(
999 CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1000 f"Authentication plugin '{plugin_name}'"
1001 f" not loaded: - {type(handler)!r} missing authenticate method",
1002 )
1003 if plugin_name == b"caching_sha2_password":
1004 return _auth.caching_sha2_password_auth(self, auth_packet)
1005 elif plugin_name == b"sha256_password":
1006 return _auth.sha256_password_auth(self, auth_packet)
1007 elif plugin_name == b"mysql_native_password":
1008 data = _auth.scramble_native_password(self.password, auth_packet.read_all())
1009 elif plugin_name == b"client_ed25519":
1010 data = _auth.ed25519_password(self.password, auth_packet.read_all())
1011 elif plugin_name == b"mysql_old_password":
1012 data = (
1013 _auth.scramble_old_password(self.password, auth_packet.read_all())
1014 + b"\0"
1015 )
1016 elif plugin_name == b"mysql_clear_password":
1017 # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html
1018 data = self.password + b"\0"
1019 elif plugin_name == b"dialog":
1020 pkt = auth_packet
1021 while True:
1022 flag = pkt.read_uint8()
1023 echo = (flag & 0x06) == 0x02
1024 last = (flag & 0x01) == 0x01
1025 prompt = pkt.read_all()
1026
1027 if prompt == b"Password: ":
1028 self.write_packet(self.password + b"\0")
1029 elif handler:
1030 resp = "no response - TypeError within plugin.prompt method"
1031 try:
1032 resp = handler.prompt(echo, prompt)
1033 self.write_packet(resp + b"\0")
1034 except AttributeError:
1035 raise err.OperationalError(
1036 CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1037 f"Authentication plugin '{plugin_name}'"
1038 f" not loaded: - {handler!r} missing prompt method",
1039 )
1040 except TypeError:
1041 raise err.OperationalError(
1042 CR.CR_AUTH_PLUGIN_ERR,
1043 f"Authentication plugin '{plugin_name}'"
1044 f" {handler!r} didn't respond with string. Returned '{resp!r}' to prompt {prompt!r}",
1045 )
1046 else:
1047 raise err.OperationalError(
1048 CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1049 f"Authentication plugin '{plugin_name}' not configured",
1050 )
1051 pkt = self._read_packet()
1052 pkt.check_error()
1053 if pkt.is_ok_packet() or last:
1054 break
1055 return pkt
1056 else:
1057 raise err.OperationalError(
1058 CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1059 "Authentication plugin '%s' not configured" % plugin_name,
1060 )
1061
1062 self.write_packet(data)
1063 pkt = self._read_packet()
1064 pkt.check_error()
1065 return pkt
1066
1067 def _get_auth_plugin_handler(self, plugin_name):
1068 plugin_class = self._auth_plugin_map.get(plugin_name)
1069 if not plugin_class and isinstance(plugin_name, bytes):
1070 plugin_class = self._auth_plugin_map.get(plugin_name.decode("ascii"))
1071 if plugin_class:
1072 try:
1073 handler = plugin_class(self)
1074 except TypeError:
1075 raise err.OperationalError(
1076 CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1077 f"Authentication plugin '{plugin_name}'"
1078 f" not loaded: - {plugin_class!r} cannot be constructed with connection object",
1079 )
1080 else:
1081 handler = None
1082 return handler
1083
1084 # _mysql support
1085 def thread_id(self):
1086 return self.server_thread_id[0]
1087
1088 def character_set_name(self):
1089 return self.charset
1090
1091 def get_host_info(self):
1092 return self.host_info
1093
1094 def get_proto_info(self):
1095 return self.protocol_version
1096
1097 def _get_server_information(self):
1098 i = 0
1099 packet = self._read_packet()
1100 data = packet.get_all_data()
1101
1102 self.protocol_version = data[i]
1103 i += 1
1104
1105 server_end = data.find(b"\0", i)
1106 self.server_version = data[i:server_end].decode("latin1")
1107 i = server_end + 1
1108
1109 self.server_thread_id = struct.unpack("<I", data[i : i + 4])
1110 i += 4
1111
1112 self.salt = data[i : i + 8]
1113 i += 9 # 8 + 1(filler)
1114
1115 self.server_capabilities = struct.unpack("<H", data[i : i + 2])[0]
1116 i += 2
1117
1118 if len(data) >= i + 6:
1119 lang, stat, cap_h, salt_len = struct.unpack("<BHHB", data[i : i + 6])
1120 i += 6
1121 # TODO: deprecate server_language and server_charset.
1122 # mysqlclient-python doesn't provide it.
1123 self.server_language = lang
1124 try:
1125 self.server_charset = charset_by_id(lang).name
1126 except KeyError:
1127 # unknown collation
1128 self.server_charset = None
1129
1130 self.server_status = stat
1131 if DEBUG:
1132 print("server_status: %x" % stat)
1133
1134 self.server_capabilities |= cap_h << 16
1135 if DEBUG:
1136 print("salt_len:", salt_len)
1137 salt_len = max(12, salt_len - 9)
1138
1139 # reserved
1140 i += 10
1141
1142 if len(data) >= i + salt_len:
1143 # salt_len includes auth_plugin_data_part_1 and filler
1144 self.salt += data[i : i + salt_len]
1145 i += salt_len
1146
1147 i += 1
1148 # AUTH PLUGIN NAME may appear here.
1149 if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
1150 # Due to Bug#59453 the auth-plugin-name is missing the terminating
1151 # NUL-char in versions prior to 5.5.10 and 5.6.2.
1152 # ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
1153 # didn't use version checks as mariadb is corrected and reports
1154 # earlier than those two.
1155 server_end = data.find(b"\0", i)
1156 if server_end < 0: # pragma: no cover - very specific upstream bug
1157 # not found \0 and last field so take it all
1158 self._auth_plugin_name = data[i:].decode("utf-8")
1159 else:
1160 self._auth_plugin_name = data[i:server_end].decode("utf-8")
1161
1162 if _DEFAULT_AUTH_PLUGIN is not None: # for tests
1163 self._auth_plugin_name = _DEFAULT_AUTH_PLUGIN
1164
1165 def get_server_info(self):
1166 return self.server_version
1167
    # Expose the exception classes of the ``err`` module as class attributes,
    # so they can be reached through this class or any of its instances
    # (for example ``conn.OperationalError``). The names match the exception
    # hierarchy listed in DB-API 2.0 (PEP 249).
    Warning = err.Warning
    Error = err.Error
    InterfaceError = err.InterfaceError
    DatabaseError = err.DatabaseError
    DataError = err.DataError
    OperationalError = err.OperationalError
    IntegrityError = err.IntegrityError
    InternalError = err.InternalError
    ProgrammingError = err.ProgrammingError
    NotSupportedError = err.NotSupportedError
1178
1179
class MySQLResult:
    """Outcome of one command, read packet by packet from a connection.

    The connection reference is kept only while more packets are expected;
    it is reset to ``None`` once reading has finished.
    """

    def __init__(self, connection):
        """
        :type connection: Connection
        """
        self.connection = connection

        # Filled from an OK packet.
        self.affected_rows = None
        self.insert_id = None
        self.server_status = None
        self.warning_count = 0
        self.message = None
        self.has_next = None

        # Filled when the command produces a result set.
        self.field_count = 0
        self.description = None
        self.rows = None

        # True while rows of an unbuffered query are still pending.
        self.unbuffered_active = False

    def __del__(self):
        # Consume pending rows of an unbuffered query before going away.
        if self.unbuffered_active:
            self._finish_unbuffered_query()

    def read(self):
        """Read a complete (buffered) result, then drop the connection reference."""
        try:
            head = self.connection._read_packet()

            if head.is_ok_packet():
                reader = self._read_ok_packet
            elif head.is_load_local_packet():
                reader = self._read_load_local_packet
            else:
                reader = self._read_result_packet
            reader(head)
        finally:
            self.connection = None

    def init_unbuffered_query(self):
        """
        :raise OperationalError: If the connection to the MySQL server is lost.
        :raise InternalError:
        """
        head = self.connection._read_packet()

        if head.is_ok_packet():
            self.connection = None
            self._read_ok_packet(head)
            return

        if head.is_load_local_packet():
            try:
                self._read_load_local_packet(head)
            finally:
                self.connection = None
            return

        self.field_count = head.read_length_encoded_integer()
        self._get_descriptions()

        # Apparently, MySQLdb picks this number because it's the maximum
        # value of a 64bit unsigned integer. Since we're emulating MySQLdb,
        # we set it to this instead of None, which would be preferred.
        self.affected_rows = 0xFFFFFFFFFFFFFFFF
        self.unbuffered_active = True

    def _read_ok_packet(self, first_packet):
        """Copy the fields of an OK packet onto this result."""
        ok_packet = OKPacketWrapper(first_packet)
        for attr in (
            "affected_rows",
            "insert_id",
            "server_status",
            "warning_count",
            "message",
            "has_next",
        ):
            setattr(self, attr, getattr(ok_packet, attr))

    def _read_load_local_packet(self, first_packet):
        """Send the requested local file, then read the closing OK packet."""
        if not self.connection._local_infile:
            raise RuntimeError(
                "**WARN**: Received LOAD_LOCAL packet but local_infile option is false."
            )
        request = LoadLocalPacketWrapper(first_packet)
        sender = LoadLocalFile(request.filename, self.connection)
        try:
            sender.send_data()
        except BaseException:
            self.connection._read_packet()  # skip ok packet
            raise

        reply = self.connection._read_packet()
        if (
            not reply.is_ok_packet()
        ):  # pragma: no cover - upstream induced protocol error
            raise err.OperationalError(
                CR.CR_COMMANDS_OUT_OF_SYNC,
                "Commands Out of Sync",
            )
        self._read_ok_packet(reply)

    def _check_packet_is_eof(self, packet):
        """Return True for an EOF packet, after recording its status fields."""
        if packet.is_eof_packet():
            # TODO: Support CLIENT.DEPRECATE_EOF
            # 1) Add DEPRECATE_EOF to CAPABILITIES
            # 2) Mask CAPABILITIES with server_capabilities
            # 3) if server_capabilities & CLIENT.DEPRECATE_EOF:
            #    use OKPacketWrapper instead of EOFPacketWrapper
            eof = EOFPacketWrapper(packet)
            self.warning_count = eof.warning_count
            self.has_next = eof.has_next
            return True
        return False

    def _read_result_packet(self, first_packet):
        """Read column descriptions and every row of a buffered result set."""
        self.field_count = first_packet.read_length_encoded_integer()
        self._get_descriptions()
        self._read_rowdata_packet()

    def _read_rowdata_packet_unbuffered(self):
        """Read one row of an active unbuffered query.

        Returns the row, or ``None`` when no query is active or the EOF
        packet was reached.
        """
        if not self.unbuffered_active:
            return None

        packet = self.connection._read_packet()
        if self._check_packet_is_eof(packet):
            # End of the result set: deactivate and release everything.
            self.unbuffered_active = False
            self.connection = None
            self.rows = None
            return None

        row = self._read_row_from_packet(packet)
        self.affected_rows = 1
        # rows should tuple of row for MySQL-python compatibility.
        self.rows = (row,)
        return row

    def _finish_unbuffered_query(self):
        """Discard the remaining packets of an active unbuffered query."""
        # After much reading on the MySQL protocol, it appears that there is,
        # in fact, no way to stop MySQL from sending all the data after
        # executing a query, so we just spin, and wait for an EOF packet.
        while self.unbuffered_active:
            try:
                packet = self.connection._read_packet()
            except err.OperationalError as e:
                timeouts = (ER.QUERY_TIMEOUT, ER.STATEMENT_TIMEOUT)
                if e.args[0] not in timeouts:
                    raise
                # if the query timed out we can simply ignore this error
                self.unbuffered_active = False
                self.connection = None
                return

            if self._check_packet_is_eof(packet):
                self.unbuffered_active = False
                # release reference to kill cyclic reference.
                self.connection = None

    def _read_rowdata_packet(self):
        """Read a rowdata packet for each data row in the result set."""
        collected = []
        packet = self.connection._read_packet()
        while not self._check_packet_is_eof(packet):
            collected.append(self._read_row_from_packet(packet))
            packet = self.connection._read_packet()
        # release reference to kill cyclic reference.
        self.connection = None

        self.affected_rows = len(collected)
        self.rows = tuple(collected)

    def _read_row_from_packet(self, packet):
        """Decode and convert the columns of one row packet into a tuple."""
        values = []
        for encoding, converter in self.converters:
            try:
                value = packet.read_length_coded_string()
            except IndexError:
                # No more columns in this row
                # See https://github.com/PyMySQL/PyMySQL/pull/434
                break
            if value is not None:
                if encoding is not None:
                    value = value.decode(encoding)
                if DEBUG:
                    print("DEBUG: DATA = ", value)
                if converter is not None:
                    value = converter(value)
            values.append(value)
        return tuple(values)

    def _get_descriptions(self):
        """Read a column descriptor packet for each column in the result."""
        self.fields = []
        self.converters = []
        use_unicode = self.connection.use_unicode
        conn_encoding = self.connection.encoding
        column_descriptions = []

        for _ in range(self.field_count):
            field = self.connection._read_packet(FieldDescriptorPacket)
            self.fields.append(field)
            column_descriptions.append(field.description())
            field_type = field.type_code

            if not use_unicode:
                encoding = None
            elif field_type == FIELD_TYPE.JSON:
                # When SELECT from JSON column: charset = binary
                # When SELECT CAST(... AS JSON): charset = connection encoding
                # This behavior is different from TEXT / BLOB.
                # We should decode result by connection encoding regardless charsetnr.
                # See https://github.com/PyMySQL/PyMySQL/issues/488
                encoding = conn_encoding
            elif field_type in TEXT_TYPES:
                # TEXTs with charset=binary (63) means BINARY types.
                encoding = None if field.charsetnr == 63 else conn_encoding
            else:
                # Integers, Dates and Times, and other basic data is encoded in ascii
                encoding = "ascii"

            converter = self.connection.decoders.get(field_type)
            if converter is converters.through:
                converter = None
            if DEBUG:
                print(f"DEBUG: field={field}, converter={converter}")
            self.converters.append((encoding, converter))

        eof_packet = self.connection._read_packet()
        assert eof_packet.is_eof_packet(), "Protocol error, expecting EOF"
        self.description = tuple(column_descriptions)
1404
1405
class LoadLocalFile:
    """Streams a local file to the server over a connection's packet writer."""

    def __init__(self, filename, connection):
        self.filename = filename
        self.connection = connection

    def send_data(self):
        """Send data packets from the local file to the server"""
        conn: Connection = self.connection
        if not conn._sock:
            raise err.InterfaceError(0, "")

        try:
            with open(self.filename, "rb") as stream:
                # 16KB is efficient enough
                packet_size = min(conn.max_allowed_packet, 16 * 1024)
                chunk = stream.read(packet_size)
                while chunk:
                    conn.write_packet(chunk)
                    chunk = stream.read(packet_size)
        except OSError:
            raise err.OperationalError(
                ER.FILE_NOT_FOUND,
                f"Can't find file '{self.filename}'",
            )
        finally:
            if not conn._closed:
                # send the empty packet to signify we are done sending data
                conn.write_packet(b"")