Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/pymysql/connections.py: 27%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# Python implementation of the MySQL client-server protocol
2# http://dev.mysql.com/doc/internals/en/client-server-protocol.html
3# Error codes:
4# https://dev.mysql.com/doc/refman/5.5/en/error-handling.html
5import errno
6import os
7import socket
8import struct
9import sys
10import traceback
11import warnings
13from . import _auth
15from .charset import charset_by_name, charset_by_id
16from .constants import CLIENT, COMMAND, CR, ER, FIELD_TYPE, SERVER_STATUS
17from . import converters
18from .cursors import Cursor
19from .optionfile import Parser
20from .protocol import (
21 dump_packet,
22 MysqlPacket,
23 FieldDescriptorPacket,
24 OKPacketWrapper,
25 EOFPacketWrapper,
26 LoadLocalPacketWrapper,
27)
28from . import err, VERSION_STRING
# Optional TLS support: without the ssl module, Connection raises
# NotImplementedError when SSL options are requested explicitly.
try:
    import ssl

    SSL_ENABLED = True
except ImportError:
    ssl = None
    SSL_ENABLED = False

# Fallback login name used when no user is given to Connection().
try:
    import getpass

    DEFAULT_USER = getpass.getuser()
    del getpass
except (ImportError, KeyError, OSError):
    # When there's no entry in OS database for a current user:
    # KeyError is raised in Python 3.12 and below.
    # OSError is raised in Python 3.13+
    DEFAULT_USER = None

# When True, packets are dumped and connection progress is printed to stdout.
DEBUG = False
_DEFAULT_AUTH_PLUGIN = None  # if this is not None, use it instead of server's default.

# Field types presumably decoded as text rather than numbers/dates
# (consumers are outside this view — confirm).
TEXT_TYPES = {
    FIELD_TYPE.BIT,
    FIELD_TYPE.BLOB,
    FIELD_TYPE.LONG_BLOB,
    FIELD_TYPE.MEDIUM_BLOB,
    FIELD_TYPE.STRING,
    FIELD_TYPE.TINY_BLOB,
    FIELD_TYPE.VAR_STRING,
    FIELD_TYPE.VARCHAR,
    FIELD_TYPE.GEOMETRY,
}


DEFAULT_CHARSET = "utf8mb4"

# Largest payload one packet header can describe (3-byte length field).
# A payload of exactly this size means "continued in the next packet".
MAX_PACKET_LEN = 2**24 - 1
70def _pack_int24(n):
71 return struct.pack("<I", n)[:3]
74# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
75def _lenenc_int(i):
76 if i < 0:
77 raise ValueError(
78 "Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i
79 )
80 elif i < 0xFB:
81 return bytes([i])
82 elif i < (1 << 16):
83 return b"\xfc" + struct.pack("<H", i)
84 elif i < (1 << 24):
85 return b"\xfd" + struct.pack("<I", i)[:3]
86 elif i < (1 << 64):
87 return b"\xfe" + struct.pack("<Q", i)
88 else:
89 raise ValueError(
90 f"Encoding {i:x} is larger than {1 << 64:x} - no representation in LengthEncodedInteger"
91 )
class Connection:
    """
    Representation of a socket with a mysql server.

    The proper way to get an instance of this class is to call
    connect().

    Establish a connection to the MySQL database. Accepts several
    arguments:

    :param host: Host where the database server is located.
    :param user: Username to log in as.
    :param password: Password to use.
    :param database: Database to use, None to not use a particular one.
    :param port: MySQL port to use, default is usually OK. (default: 3306)
    :param bind_address: When the client has multiple network interfaces, specify
        the interface from which to connect to the host. Argument can be
        a hostname or an IP address.
    :param unix_socket: Use a unix socket rather than TCP/IP.
    :param read_timeout: The timeout for reading from the connection in seconds.
        (default: None - no timeout)
    :param write_timeout: The timeout for writing to the connection in seconds.
        (default: None - no timeout)
    :param str charset: Charset to use.
    :param str collation: Collation name to use.
    :param sql_mode: Default SQL_MODE to use.
    :param read_default_file:
        Specifies my.cnf file to read these parameters from under the [client] section.
    :param conv:
        Conversion dictionary to use instead of the default one.
        This is used to provide custom marshalling and unmarshalling of types.
        See converters.
    :param use_unicode:
        Whether or not to default to unicode strings.
        This option defaults to true.
    :param client_flag: Custom flags to send to MySQL. Find potential values in constants.CLIENT.
    :param cursorclass: Custom cursor class to use.
    :param init_command: Initial SQL statement to run when connection is established.
    :param connect_timeout: The timeout for connecting to the database in seconds.
        (default: 10; must be > 0 and <= 31536000)
    :param ssl: An ssl.SSLContext, or a dict of arguments similar to mysql_ssl_set()'s parameters.
        Passing a dict is deprecated; use the individual ``ssl_*`` parameters or an
        ``ssl.SSLContext`` instead.
    :param ssl_ca: Path to the file that contains a PEM-formatted CA certificate.
    :param ssl_cert: Path to the file that contains a PEM-formatted client certificate.
    :param ssl_disabled: A boolean value that disables usage of TLS. Unlike other SSL options,
        setting this to True explicitly prohibits the use of TLS, even if the server supports it.
        When no SSL option is given and this is not True, TLS is used if the server
        supports it and skipped otherwise (no certificate verification in that mode).
    :param ssl_key: Path to the file that contains a PEM-formatted private key for
        the client certificate.
    :param ssl_key_password: The password for the client certificate private key.
    :param ssl_verify_cert: Set to true to check the server certificate's validity.
    :param ssl_verify_identity: Set to true to check the server's identity.
    :param read_default_group: Group to read from in the configuration file.
    :param autocommit: Autocommit mode. None means use server default. (default: False)
    :param local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
    :param max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
        Only used to limit size of "LOAD LOCAL INFILE" data packet smaller than default (16KB).
    :param defer_connect: Don't explicitly connect on construction - wait for connect call.
        (default: False)
    :param auth_plugin_map: A dict of plugin names to a class that processes that plugin.
        The class will take the Connection object as the argument to the constructor.
        The class needs an authenticate method taking an authentication packet as
        an argument.  For the dialog plugin, a prompt(echo, prompt) method can be used
        (if no authenticate method) for returning a string from the user. (experimental)
    :param program_name: If given, sent to the server as the ``program_name``
        connection attribute (when the server supports connection attributes).
    :param server_public_key: SHA256 authentication plugin public key value. (default: None)
    :param binary_prefix: Add _binary prefix on bytes and bytearray. (default: False)
    :param compress: Not supported.
    :param named_pipe: Not supported.
    :param db: **DEPRECATED** Alias for database.
    :param passwd: **DEPRECATED** Alias for password.

    See `Connection <https://www.python.org/dev/peps/pep-0249/#connection-objects>`_ in the
    specification.
    """

    # Connected socket; None while not connected (see the ``open`` property).
    _sock = None
    # Buffered binary reader created with ``_sock.makefile("rb")``.
    _rfile = None
    # Auth plugin name from the server; "" selects the native-password scramble.
    # NOTE(review): assigned outside this view, presumably by
    # _get_server_information() — confirm.
    _auth_plugin_name = ""
    # Set by close(); reset by connect().
    _closed = False
    # True once traffic runs over TLS or a UNIX socket.
    _secure = False
    def __init__(
        self,
        *,
        user=None,  # The first four arguments are based on DB-API 2.0 recommendation.
        password="",
        host=None,
        database=None,
        unix_socket=None,
        port=0,
        charset="",
        collation=None,
        sql_mode=None,
        read_default_file=None,
        conv=None,
        use_unicode=True,
        client_flag=0,
        cursorclass=Cursor,
        init_command=None,
        connect_timeout=10,
        read_default_group=None,
        autocommit=False,
        local_infile=False,
        max_allowed_packet=16 * 1024 * 1024,
        defer_connect=False,
        auth_plugin_map=None,
        read_timeout=None,
        write_timeout=None,
        bind_address=None,
        binary_prefix=False,
        program_name=None,
        server_public_key=None,
        ssl=None,
        ssl_ca=None,
        ssl_cert=None,
        ssl_disabled=None,
        ssl_key=None,
        ssl_key_password=None,
        ssl_verify_cert=None,
        ssl_verify_identity=None,
        compress=None,  # not supported
        named_pipe=None,  # not supported
        passwd=None,  # deprecated
        db=None,  # deprecated
    ):
        """Validate and store connection settings, then connect unless deferred.

        Parameters are documented on the class. Values missing here are filled
        from the option file when *read_default_file*/*read_default_group* is set.

        :raise NotImplementedError: If *compress* or *named_pipe* is set, or SSL
            options are given but the ssl module is unavailable.
        :raise ValueError: On an invalid port or timeout.
        """
        # Deprecated aliases: only used when the new-style argument is unset.
        if db is not None and database is None:
            warnings.warn("'db' is deprecated, use 'database'", DeprecationWarning, 3)
            database = db
        if passwd is not None and not password:
            warnings.warn(
                "'passwd' is deprecated, use 'password'", DeprecationWarning, 3
            )
            password = passwd

        if compress or named_pipe:
            raise NotImplementedError(
                "compress and named_pipe arguments are not supported"
            )

        self._local_infile = bool(local_infile)
        if self._local_infile:
            client_flag |= CLIENT.LOCAL_FILES

        # A group without a file implies the platform's default option file.
        if read_default_group and not read_default_file:
            if sys.platform.startswith("win"):
                read_default_file = "c:\\my.ini"
            else:
                read_default_file = "/etc/my.cnf"

        if read_default_file:
            if not read_default_group:
                read_default_group = "client"

            cfg = Parser()
            cfg.read(os.path.expanduser(read_default_file))

            # Explicit (truthy) arguments win; otherwise use the option file,
            # falling back to the original value if the key is missing.
            def _config(key, arg):
                if arg:
                    return arg
                try:
                    return cfg.get(read_default_group, key)
                except Exception:
                    return arg

            user = _config("user", user)
            password = _config("password", password)
            host = _config("host", host)
            database = _config("database", database)
            unix_socket = _config("socket", unix_socket)
            port = int(_config("port", port))
            bind_address = _config("bind-address", bind_address)
            charset = _config("default-character-set", charset)
            if not ssl:
                ssl = {}
            if isinstance(ssl, dict):
                for key in ["ca", "capath", "cert", "key", "password", "cipher"]:
                    value = _config("ssl-" + key, ssl.get(key))
                    if value:
                        ssl[key] = value

        # TLS mode:
        #   ssl_disabled            -> no TLS at all
        #   any SSL option given    -> REQUIRED (CLIENT.SSL requested up front)
        #   otherwise (ssl present) -> PREFERRED (use TLS only if server offers it)
        self.ssl = False
        self._ssl_required = False
        if not ssl_disabled:
            if ssl_ca or ssl_cert or ssl_key or ssl_verify_cert or ssl_verify_identity:
                ssl = {
                    "ca": ssl_ca,
                    "check_hostname": bool(ssl_verify_identity),
                    "verify_mode": ssl_verify_cert
                    if ssl_verify_cert is not None
                    else False,
                }
                if ssl_cert is not None:
                    ssl["cert"] = ssl_cert
                if ssl_key is not None:
                    ssl["key"] = ssl_key
                if ssl_key_password is not None:
                    ssl["password"] = ssl_key_password
            if ssl:
                if not SSL_ENABLED:
                    raise NotImplementedError("ssl module not found")
                self.ssl = True
                self._ssl_required = True
                client_flag |= CLIENT.SSL
                self.ctx = self._create_ssl_ctx(ssl)
            elif SSL_ENABLED:
                # No explicit SSL options specified: use PREFERRED mode.
                # Attempt SSL but fall back gracefully if the server doesn't support it.
                self.ssl = True
                self._ssl_required = False
                self.ctx = self._create_ssl_ctx({})

        self.host = host or "localhost"
        self.port = port or 3306
        if type(self.port) is not int:
            raise ValueError("port should be of type int")
        self.user = user or DEFAULT_USER
        self.password = password or b""
        if isinstance(self.password, str):
            self.password = self.password.encode("latin1")
        self.db = database
        self.unix_socket = unix_socket
        self.bind_address = bind_address
        if not (0 < connect_timeout <= 31536000):
            raise ValueError("connect_timeout should be >0 and <=31536000")
        # NOTE(review): the range check above makes `or None` unreachable.
        self.connect_timeout = connect_timeout or None
        if read_timeout is not None and read_timeout <= 0:
            raise ValueError("read_timeout should be > 0")
        self._read_timeout = read_timeout
        if write_timeout is not None and write_timeout <= 0:
            raise ValueError("write_timeout should be > 0")
        self._write_timeout = write_timeout

        self.charset = charset or DEFAULT_CHARSET
        self.collation = collation
        self.use_unicode = use_unicode

        self.encoding = charset_by_name(self.charset).encoding

        client_flag |= CLIENT.CAPABILITIES
        if self.db:
            client_flag |= CLIENT.CONNECT_WITH_DB

        self.client_flag = client_flag

        self.cursorclass = cursorclass

        self._result = None
        self._affected_rows = 0
        self.host_info = "Not connected"

        # specified autocommit mode. None means use server default.
        self.autocommit_mode = autocommit

        if conv is None:
            conv = converters.conversions

        # Need for MySQLdb compatibility.
        # Non-int keys are encoders (Python types); int keys are decoders
        # (presumably FIELD_TYPE codes).
        self.encoders = {k: v for (k, v) in conv.items() if type(k) is not int}
        self.decoders = {k: v for (k, v) in conv.items() if type(k) is int}
        self.sql_mode = sql_mode
        self.init_command = init_command
        self.max_allowed_packet = max_allowed_packet
        self._auth_plugin_map = auth_plugin_map or {}
        self._binary_prefix = binary_prefix
        self.server_public_key = server_public_key

        # Sent in the handshake response when the server supports CONNECT_ATTRS.
        self._connect_attrs = {
            "_client_name": "pymysql",
            "_client_version": VERSION_STRING,
            "_pid": str(os.getpid()),
        }

        if program_name:
            self._connect_attrs["program_name"] = program_name

        if defer_connect:
            self._sock = None
        else:
            self.connect()
374 def __enter__(self):
375 return self
377 def __exit__(self, *exc_info):
378 del exc_info
379 self.close()
381 def _create_ssl_ctx(self, sslp):
382 if isinstance(sslp, ssl.SSLContext):
383 return sslp
384 ca = sslp.get("ca")
385 capath = sslp.get("capath")
386 hasnoca = ca is None and capath is None
387 ctx = ssl.create_default_context(cafile=ca, capath=capath)
389 # Python 3.13 enables VERIFY_X509_STRICT by default.
390 # But self signed certificates that are generated by MySQL automatically
391 # doesn't pass the verification.
392 ctx.verify_flags &= ~ssl.VERIFY_X509_STRICT
394 ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True)
395 verify_mode_value = sslp.get("verify_mode")
396 if verify_mode_value is None:
397 ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
398 elif isinstance(verify_mode_value, bool):
399 ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE
400 else:
401 if isinstance(verify_mode_value, str):
402 verify_mode_value = verify_mode_value.lower()
403 if verify_mode_value in ("none", "0", "false", "no"):
404 ctx.verify_mode = ssl.CERT_NONE
405 elif verify_mode_value == "optional":
406 ctx.verify_mode = ssl.CERT_OPTIONAL
407 elif verify_mode_value in ("required", "1", "true", "yes"):
408 ctx.verify_mode = ssl.CERT_REQUIRED
409 else:
410 ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
411 if "cert" in sslp:
412 ctx.load_cert_chain(
413 sslp["cert"], keyfile=sslp.get("key"), password=sslp.get("password")
414 )
415 if "cipher" in sslp:
416 ctx.set_ciphers(sslp["cipher"])
417 ctx.options |= ssl.OP_NO_SSLv2
418 ctx.options |= ssl.OP_NO_SSLv3
419 return ctx
421 def close(self):
422 """
423 Send the quit message and close the socket.
425 See `Connection.close() <https://www.python.org/dev/peps/pep-0249/#Connection.close>`_
426 in the specification.
428 :raise Error: If the connection is already closed.
429 """
430 if self._closed:
431 raise err.Error("Already closed")
432 self._closed = True
433 if self._sock is None:
434 return
435 send_data = struct.pack("<iB", 1, COMMAND.COM_QUIT)
436 try:
437 self._write_bytes(send_data)
438 except Exception:
439 pass
440 finally:
441 self._force_close()
443 @property
444 def open(self):
445 """Return True if the connection is open."""
446 return self._sock is not None
448 def _force_close(self):
449 """Close connection without QUIT message."""
450 if self._rfile:
451 self._rfile.close()
452 if self._sock:
453 try:
454 self._sock.close()
455 except: # noqa
456 pass
457 self._sock = None
458 self._rfile = None
    # On garbage collection, drop the socket without sending COM_QUIT.
    __del__ = _force_close
462 def autocommit(self, value):
463 self.autocommit_mode = bool(value)
464 current = self.get_autocommit()
465 if value != current:
466 self._send_autocommit_mode()
468 def get_autocommit(self):
469 return bool(self.server_status & SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT)
471 def _read_ok_packet(self):
472 pkt = self._read_packet()
473 if not pkt.is_ok_packet():
474 raise err.OperationalError(
475 CR.CR_COMMANDS_OUT_OF_SYNC,
476 "Command Out of Sync",
477 )
478 ok = OKPacketWrapper(pkt)
479 self.server_status = ok.server_status
480 return ok
482 def _send_autocommit_mode(self):
483 """Set whether or not to commit after every execute()."""
484 self._execute_command(
485 COMMAND.COM_QUERY, "SET AUTOCOMMIT = %s" % self.escape(self.autocommit_mode)
486 )
487 self._read_ok_packet()
489 def begin(self):
490 """Begin transaction."""
491 self._execute_command(COMMAND.COM_QUERY, "BEGIN")
492 self._read_ok_packet()
494 def commit(self):
495 """
496 Commit changes to stable storage.
498 See `Connection.commit() <https://www.python.org/dev/peps/pep-0249/#commit>`_
499 in the specification.
500 """
501 self._execute_command(COMMAND.COM_QUERY, "COMMIT")
502 self._read_ok_packet()
504 def rollback(self):
505 """
506 Roll back the current transaction.
508 See `Connection.rollback() <https://www.python.org/dev/peps/pep-0249/#rollback>`_
509 in the specification.
510 """
511 self._execute_command(COMMAND.COM_QUERY, "ROLLBACK")
512 self._read_ok_packet()
514 def show_warnings(self):
515 """Send the "SHOW WARNINGS" SQL command."""
516 self._execute_command(COMMAND.COM_QUERY, "SHOW WARNINGS")
517 result = MySQLResult(self)
518 result.read()
519 return result.rows
521 def select_db(self, db):
522 """
523 Set current db.
525 :param db: The name of the db.
526 """
527 self._execute_command(COMMAND.COM_INIT_DB, db)
528 self._read_ok_packet()
530 def escape(self, obj, mapping=None):
531 """Escape whatever value is passed.
533 Non-standard, for internal use; do not use this in your applications.
534 """
535 if isinstance(obj, str):
536 return "'" + self.escape_string(obj) + "'"
537 if isinstance(obj, (bytes, bytearray)):
538 ret = self._quote_bytes(obj)
539 if self._binary_prefix:
540 ret = "_binary" + ret
541 return ret
542 return converters.escape_item(obj, self.charset, mapping=mapping)
544 def literal(self, obj):
545 """Alias for escape().
547 Non-standard, for internal use; do not use this in your applications.
548 """
549 return self.escape(obj, self.encoders)
551 def escape_string(self, s):
552 if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES:
553 return s.replace("'", "''")
554 return converters.escape_string(s)
556 def _quote_bytes(self, s):
557 if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES:
558 return "'{}'".format(
559 s.replace(b"'", b"''").decode("ascii", "surrogateescape")
560 )
561 return converters.escape_bytes(s)
563 def cursor(self, cursor=None):
564 """
565 Create a new cursor to execute queries with.
567 :param cursor: The type of cursor to create. None means use Cursor.
568 :type cursor: :py:class:`Cursor`, :py:class:`SSCursor`, :py:class:`DictCursor`,
569 or :py:class:`SSDictCursor`.
570 """
571 if cursor:
572 return cursor(self)
573 return self.cursorclass(self)
575 # The following methods are INTERNAL USE ONLY (called from Cursor)
576 def query(self, sql, unbuffered=False):
577 # if DEBUG:
578 # print("DEBUG: sending query:", sql)
579 if isinstance(sql, str):
580 sql = sql.encode(self.encoding, "surrogateescape")
581 self._execute_command(COMMAND.COM_QUERY, sql)
582 self._affected_rows = self._read_query_result(unbuffered=unbuffered)
583 return self._affected_rows
585 def next_result(self, unbuffered=False):
586 self._affected_rows = self._read_query_result(unbuffered=unbuffered)
587 return self._affected_rows
589 def affected_rows(self):
590 return self._affected_rows
592 def kill(self, thread_id):
593 if not isinstance(thread_id, int):
594 raise TypeError("thread_id must be an integer")
595 self.query(f"KILL {thread_id:d}")
597 def ping(self, reconnect=False):
598 """
599 Check if the server is alive.
601 `reconnect` is deprecated. Create a new connection if you want to reconnect.
603 :param reconnect: If the connection is closed, reconnect.
604 :type reconnect: boolean
606 :raise Error: If the connection is closed and reconnect=False.
607 """
608 # emit deprecation warning for reconnect.
609 if reconnect:
610 warnings.warn(
611 "The 'reconnect' argument is deprecated. Create a new connection if you want to reconnect.",
612 DeprecationWarning,
613 2,
614 )
615 if self._sock is None:
616 if reconnect:
617 self.connect()
618 reconnect = False
619 else:
620 raise err.Error("Already closed")
621 try:
622 self._execute_command(COMMAND.COM_PING, "")
623 self._read_ok_packet()
624 except Exception:
625 if reconnect:
626 self.connect()
627 self.ping(False)
628 else:
629 raise
631 def set_charset(self, charset):
632 """Deprecated. Use set_character_set() instead."""
633 warnings.warn(
634 "'set_charset' is deprecated, use 'set_character_set' instead",
635 DeprecationWarning,
636 2,
637 )
638 # This function has been implemented in old PyMySQL.
639 # But this name is different from MySQLdb.
640 # So we keep this function for compatibility and add
641 # new set_character_set() function.
642 self.set_character_set(charset)
644 def set_character_set(self, charset, collation=None):
645 """
646 Set charaset (and collation)
648 Send "SET NAMES charset [COLLATE collation]" query.
649 Update Connection.encoding based on charset.
650 """
651 # Make sure charset is supported.
652 encoding = charset_by_name(charset).encoding
654 if collation:
655 query = f"SET NAMES {charset} COLLATE {collation}"
656 else:
657 query = f"SET NAMES {charset}"
658 self._execute_command(COMMAND.COM_QUERY, query)
659 self._read_packet()
660 self.charset = charset
661 self.encoding = encoding
662 self.collation = collation
    def connect(self, sock=None):
        """Establish the connection and run the session setup.

        Opens a UNIX-socket or TCP connection (unless *sock* is supplied),
        reads the server greeting, authenticates, then sends ``SET NAMES`` and
        the optional ``sql_mode``, ``init_command`` and autocommit settings.

        :param sock: An already-connected socket to use instead of opening one.
        :raise OperationalError: If an ``OSError`` occurs during setup. The
            original exception is kept on ``original_exception`` and its
            formatted traceback on ``traceback``.
        """
        self._closed = False
        try:
            if sock is None:
                if self.unix_socket:
                    sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                    sock.settimeout(self.connect_timeout)
                    sock.connect(self.unix_socket)
                    self.host_info = "Localhost via UNIX socket"
                    # A local UNIX socket is treated as a secure transport.
                    self._secure = True
                    if DEBUG:
                        print("connected using unix_socket")
                else:
                    kwargs = {}
                    if self.bind_address is not None:
                        kwargs["source_address"] = (self.bind_address, 0)
                    while True:
                        try:
                            sock = socket.create_connection(
                                (self.host, self.port), self.connect_timeout, **kwargs
                            )
                            break
                        except OSError as e:
                            # Retry when interrupted by a signal.
                            if e.errno == errno.EINTR:
                                continue
                            raise
                    self.host_info = "socket %s:%d" % (self.host, self.port)
                    if DEBUG:
                        print("connected using socket")
                    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                    sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
                # connect_timeout only bounds connection establishment;
                # _read_bytes/_write_bytes set their own timeouts per call.
                sock.settimeout(None)

            self._sock = sock
            self._rfile = sock.makefile("rb")
            self._next_seq_id = 0

            self._get_server_information()
            self._request_authentication()

            # Send "SET NAMES" query on init for:
            # - Ensure charset (and collation) is set to the server.
            #   - collation_id in handshake packet may be ignored.
            # - If collation is not specified, we don't know what is server's
            #   default collation for the charset. For example, default collation
            #   of utf8mb4 is:
            #   - MySQL 5.7, MariaDB 10.x: utf8mb4_general_ci
            #   - MySQL 8.0: utf8mb4_0900_ai_ci
            #
            # Reference:
            # - https://github.com/PyMySQL/PyMySQL/issues/1092
            # - https://github.com/wagtail/wagtail/issues/9477
            # - https://zenn.dev/methane/articles/2023-mysql-collation (Japanese)
            self.set_character_set(self.charset, self.collation)

            if self.sql_mode is not None:
                c = self.cursor()
                c.execute("SET sql_mode=%s", (self.sql_mode,))
                c.close()

            if self.init_command is not None:
                c = self.cursor()
                c.execute(self.init_command)
                c.close()

            if self.autocommit_mode is not None:
                self.autocommit(self.autocommit_mode)
        except BaseException as e:
            self._force_close()

            if isinstance(e, (OSError, IOError)):
                exc = err.OperationalError(
                    CR.CR_CONN_HOST_ERROR,
                    f"Can't connect to MySQL server on {self.host!r} ({e})",
                )
                # Keep original exception and traceback to investigate error.
                exc.original_exception = e
                exc.traceback = traceback.format_exc()
                if DEBUG:
                    print(exc.traceback)
                raise exc

            # Anything that is not an OSError (e.g. the err.* exceptions raised
            # while reading packets, or an unexpected bug) is re-raised unchanged;
            # wrapping it would hide the original error.
            raise
751 def write_packet(self, payload):
752 """Writes an entire "mysql packet" in its entirety to the network
753 adding its length and sequence number.
754 """
755 # Internal note: when you build packet manually and calls _write_bytes()
756 # directly, you should set self._next_seq_id properly.
757 data = _pack_int24(len(payload)) + bytes([self._next_seq_id]) + payload
758 if DEBUG:
759 dump_packet(data)
760 self._write_bytes(data)
761 self._next_seq_id = (self._next_seq_id + 1) % 256
    def _read_packet(self, packet_type=MysqlPacket):
        """Read an entire "mysql packet" in its entirety from the network
        and return a MysqlPacket type that represents the results.

        Payloads split across several wire packets (each exactly
        MAX_PACKET_LEN bytes, followed by a shorter one) are reassembled.

        :param packet_type: Class used to wrap the reassembled payload.
        :raise OperationalError: If the connection to the MySQL server is lost.
        :raise InternalError: If the packet sequence number is wrong.
        """
        buff = bytearray()
        while True:
            packet_header = self._read_bytes(4)
            # if DEBUG: dump_packet(packet_header)

            # Header: 3-byte little-endian payload length + 1-byte sequence id.
            btrl, btrh, packet_number = struct.unpack("<HBB", packet_header)
            bytes_to_read = btrl + (btrh << 16)
            if packet_number != self._next_seq_id:
                self._force_close()
                if packet_number == 0:
                    # MariaDB sends error packet with seqno==0 when shutdown
                    raise err.OperationalError(
                        CR.CR_SERVER_LOST,
                        "Lost connection to MySQL server during query",
                    )
                raise err.InternalError(
                    "Packet sequence number wrong - got %d expected %d"
                    % (packet_number, self._next_seq_id)
                )
            self._next_seq_id = (self._next_seq_id + 1) % 256

            recv_data = self._read_bytes(bytes_to_read)
            if DEBUG:
                dump_packet(recv_data)
            buff += recv_data
            # https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
            if bytes_to_read < MAX_PACKET_LEN:
                break

        packet = packet_type(bytes(buff), self.encoding)
        if packet.is_error_packet():
            # An error ends any streaming (unbuffered) result in progress.
            if self._result is not None and self._result.unbuffered_active is True:
                self._result.unbuffered_active = False
            packet.raise_for_error()
        return packet
806 def _read_bytes(self, num_bytes):
807 self._sock.settimeout(self._read_timeout)
808 while True:
809 try:
810 data = self._rfile.read(num_bytes)
811 break
812 except OSError as e:
813 if e.errno == errno.EINTR:
814 continue
815 self._force_close()
816 raise err.OperationalError(
817 CR.CR_SERVER_LOST,
818 f"Lost connection to MySQL server during query ({e})",
819 )
820 except BaseException:
821 # Don't convert unknown exception to MySQLError.
822 self._force_close()
823 raise
824 if len(data) < num_bytes:
825 self._force_close()
826 raise err.OperationalError(
827 CR.CR_SERVER_LOST, "Lost connection to MySQL server during query"
828 )
829 return data
831 def _write_bytes(self, data):
832 self._sock.settimeout(self._write_timeout)
833 try:
834 self._sock.sendall(data)
835 except OSError as e:
836 self._force_close()
837 raise err.OperationalError(
838 CR.CR_SERVER_GONE_ERROR, f"MySQL server has gone away ({e!r})"
839 )
841 def _read_query_result(self, unbuffered=False):
842 self._result = None
843 result = MySQLResult(self)
844 if unbuffered:
845 result.init_unbuffered_query()
846 else:
847 result.read()
848 self._result = result
849 if result.server_status is not None:
850 self.server_status = result.server_status
851 return result.affected_rows
853 def insert_id(self):
854 if self._result:
855 return self._result.insert_id
856 else:
857 return 0
    def _execute_command(self, command, sql):
        """Send *command* with *sql* as its argument, splitting large payloads.

        Any pending unbuffered or multi-result response is drained first.
        ``str`` arguments are encoded with the connection encoding.

        :raise InterfaceError: If the connection is closed.
        :raise OperationalError: If writing to the socket fails.
        """
        if not self._sock:
            raise err.InterfaceError(0, "")

        # If the last query was unbuffered, make sure it finishes before
        # sending new commands
        if self._result is not None:
            if self._result.unbuffered_active:
                warnings.warn("Previous unbuffered result was left incomplete")
                self._result._finish_unbuffered_query()
            while self._result.has_next:
                self.next_result()
            self._result = None

        if isinstance(sql, str):
            sql = sql.encode(self.encoding)

        packet_size = min(MAX_PACKET_LEN, len(sql) + 1)  # +1 is for command

        # tiny optimization: build first packet manually instead of
        # calling self.write_packet()
        # "<i" packs packet_size as 4 little-endian bytes. Because
        # packet_size <= MAX_PACKET_LEN < 2**24, the 4th byte is 0, which is
        # exactly sequence id 0 of the packet header. "B" is the command byte.
        prelude = struct.pack("<iB", packet_size, command)
        packet = prelude + sql[: packet_size - 1]
        self._write_bytes(packet)
        if DEBUG:
            dump_packet(packet)
        self._next_seq_id = 1

        if packet_size < MAX_PACKET_LEN:
            return

        # A full-size first packet means the payload continues; send the rest
        # in MAX_PACKET_LEN chunks. A final chunk shorter than MAX_PACKET_LEN
        # (possibly empty) terminates the sequence.
        sql = sql[packet_size - 1 :]
        while True:
            packet_size = min(MAX_PACKET_LEN, len(sql))
            self.write_packet(sql[:packet_size])
            sql = sql[packet_size:]
            if not sql and packet_size < MAX_PACKET_LEN:
                break
902 def _request_authentication(self):
903 # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
904 if int(self.server_version.split(".", 1)[0]) >= 5:
905 self.client_flag |= CLIENT.MULTI_RESULTS
907 if self.user is None:
908 raise ValueError("Did not specify a username")
910 charset_id = charset_by_name(self.charset).id
911 if isinstance(self.user, str):
912 self.user = self.user.encode(self.encoding)
914 # Determine flags for the initial handshake packet.
915 # CLIENT.SSL is added conditionally: for REQUIRED mode it is already set in
916 # self.client_flag, but for PREFERRED mode it is only added when the server
917 # also advertises SSL support.
918 # _do_ssl is set here and checked below for sha256_password auth.
919 client_flags = self.client_flag
920 if self.ssl:
921 if self.server_capabilities & CLIENT.SSL:
922 # SSL upgrade: include CLIENT.SSL flag and wrap the socket.
923 _do_ssl = True
924 client_flags |= CLIENT.SSL
925 elif self._ssl_required:
926 raise err.OperationalError(
927 CR.CR_SSL_CONNECTION_ERROR,
928 "SSL is required but the server doesn't support it",
929 )
930 else:
931 # PREFERRED mode: server doesn't support SSL, fall back to non-SSL.
932 _do_ssl = False
933 else:
934 _do_ssl = False
936 data_init = struct.pack(
937 "<iIB23s", client_flags, MAX_PACKET_LEN, charset_id, b""
938 )
940 if _do_ssl:
941 self.write_packet(data_init)
942 self._sock = self.ctx.wrap_socket(self._sock, server_hostname=self.host)
943 self._rfile = self._sock.makefile("rb")
944 self._secure = True
946 data = data_init + self.user + b"\0"
948 authresp = b""
949 plugin_name = None
951 if self._auth_plugin_name == "":
952 plugin_name = b""
953 authresp = _auth.scramble_native_password(self.password, self.salt)
954 elif self._auth_plugin_name == "mysql_native_password":
955 plugin_name = b"mysql_native_password"
956 authresp = _auth.scramble_native_password(self.password, self.salt)
957 elif self._auth_plugin_name == "caching_sha2_password":
958 plugin_name = b"caching_sha2_password"
959 if self.password:
960 if DEBUG:
961 print("caching_sha2: trying fast path")
962 authresp = _auth.scramble_caching_sha2(self.password, self.salt)
963 else:
964 if DEBUG:
965 print("caching_sha2: empty password")
966 elif self._auth_plugin_name == "sha256_password":
967 plugin_name = b"sha256_password"
968 if _do_ssl:
969 authresp = self.password + b"\0"
970 elif self.password:
971 authresp = b"\1" # request public key
972 else:
973 authresp = b"\0" # empty password
975 if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
976 data += _lenenc_int(len(authresp)) + authresp
977 elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
978 data += struct.pack("B", len(authresp)) + authresp
979 else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
980 data += authresp + b"\0"
982 if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
983 if isinstance(self.db, str):
984 self.db = self.db.encode(self.encoding)
985 data += self.db + b"\0"
987 if self.server_capabilities & CLIENT.PLUGIN_AUTH:
988 data += (plugin_name or b"") + b"\0"
990 if self.server_capabilities & CLIENT.CONNECT_ATTRS:
991 connect_attrs = b""
992 for k, v in self._connect_attrs.items():
993 k = k.encode("utf-8")
994 connect_attrs += _lenenc_int(len(k)) + k
995 v = v.encode("utf-8")
996 connect_attrs += _lenenc_int(len(v)) + v
997 data += _lenenc_int(len(connect_attrs)) + connect_attrs
999 self.write_packet(data)
1000 auth_packet = self._read_packet()
1002 # if authentication method isn't accepted the first byte
1003 # will have the octet 254
1004 if auth_packet.is_auth_switch_request():
1005 if DEBUG:
1006 print("received auth switch")
1007 # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
1008 auth_packet.read_uint8() # 0xfe packet identifier
1009 plugin_name = auth_packet.read_string()
1010 if (
1011 self.server_capabilities & CLIENT.PLUGIN_AUTH
1012 and plugin_name is not None
1013 ):
1014 auth_packet = self._process_auth(plugin_name, auth_packet)
1015 else:
1016 raise err.OperationalError("received unknown auth switch request")
1017 elif auth_packet.is_extra_auth_data():
1018 if DEBUG:
1019 print("received extra data")
1020 # https://dev.mysql.com/doc/internals/en/successful-authentication.html
1021 if self._auth_plugin_name == "caching_sha2_password":
1022 auth_packet = _auth.caching_sha2_password_auth(self, auth_packet)
1023 elif self._auth_plugin_name == "sha256_password":
1024 auth_packet = _auth.sha256_password_auth(self, auth_packet)
1025 else:
1026 raise err.OperationalError(
1027 "Received extra packet for auth method %r", self._auth_plugin_name
1028 )
1030 if DEBUG:
1031 print("Succeed to auth")
1033 def _process_auth(self, plugin_name, auth_packet):
1034 handler = self._get_auth_plugin_handler(plugin_name)
1035 if handler:
1036 try:
1037 return handler.authenticate(auth_packet)
1038 except AttributeError:
1039 if plugin_name != b"dialog":
1040 raise err.OperationalError(
1041 CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1042 f"Authentication plugin '{plugin_name}'"
1043 f" not loaded: - {type(handler)!r} missing authenticate method",
1044 )
1045 if plugin_name == b"caching_sha2_password":
1046 return _auth.caching_sha2_password_auth(self, auth_packet)
1047 elif plugin_name == b"sha256_password":
1048 return _auth.sha256_password_auth(self, auth_packet)
1049 elif plugin_name == b"mysql_native_password":
1050 data = _auth.scramble_native_password(self.password, auth_packet.read_all())
1051 elif plugin_name == b"client_ed25519":
1052 data = _auth.ed25519_password(self.password, auth_packet.read_all())
1053 elif plugin_name == b"mysql_old_password":
1054 data = (
1055 _auth.scramble_old_password(self.password, auth_packet.read_all())
1056 + b"\0"
1057 )
1058 elif plugin_name == b"mysql_clear_password":
1059 # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html
1060 data = self.password + b"\0"
1061 elif plugin_name == b"dialog":
1062 pkt = auth_packet
1063 while True:
1064 flag = pkt.read_uint8()
1065 echo = (flag & 0x06) == 0x02
1066 last = (flag & 0x01) == 0x01
1067 prompt = pkt.read_all()
1069 if prompt == b"Password: ":
1070 self.write_packet(self.password + b"\0")
1071 elif handler:
1072 resp = "no response - TypeError within plugin.prompt method"
1073 try:
1074 resp = handler.prompt(echo, prompt)
1075 self.write_packet(resp + b"\0")
1076 except AttributeError:
1077 raise err.OperationalError(
1078 CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1079 f"Authentication plugin '{plugin_name}'"
1080 f" not loaded: - {handler!r} missing prompt method",
1081 )
1082 except TypeError:
1083 raise err.OperationalError(
1084 CR.CR_AUTH_PLUGIN_ERR,
1085 f"Authentication plugin '{plugin_name}'"
1086 f" {handler!r} didn't respond with string. Returned '{resp!r}' to prompt {prompt!r}",
1087 )
1088 else:
1089 raise err.OperationalError(
1090 CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1091 f"Authentication plugin '{plugin_name}' not configured",
1092 )
1093 pkt = self._read_packet()
1094 pkt.check_error()
1095 if pkt.is_ok_packet() or last:
1096 break
1097 return pkt
1098 else:
1099 raise err.OperationalError(
1100 CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1101 "Authentication plugin '%s' not configured" % plugin_name,
1102 )
1104 self.write_packet(data)
1105 pkt = self._read_packet()
1106 pkt.check_error()
1107 return pkt
1109 def _get_auth_plugin_handler(self, plugin_name):
1110 plugin_class = self._auth_plugin_map.get(plugin_name)
1111 if not plugin_class and isinstance(plugin_name, bytes):
1112 plugin_class = self._auth_plugin_map.get(plugin_name.decode("ascii"))
1113 if plugin_class:
1114 try:
1115 handler = plugin_class(self)
1116 except TypeError:
1117 raise err.OperationalError(
1118 CR.CR_AUTH_PLUGIN_CANNOT_LOAD,
1119 f"Authentication plugin '{plugin_name}'"
1120 f" not loaded: - {plugin_class!r} cannot be constructed with connection object",
1121 )
1122 else:
1123 handler = None
1124 return handler
# _mysql support
def thread_id(self):
    """Return the connection id the server reported in its handshake.

    ``server_thread_id`` holds the tuple from ``struct.unpack("<I", ...)``;
    its first (and only) element is the id.
    """
    unpacked = self.server_thread_id
    return unpacked[0]
def character_set_name(self):
    """Return the connection's character set name (``self.charset``)."""
    name = self.charset
    return name
def get_host_info(self):
    """Return ``self.host_info`` unchanged."""
    info = self.host_info
    return info
def get_proto_info(self):
    """Return the protocol version byte read from the server handshake."""
    version = self.protocol_version
    return version
def _get_server_information(self):
    """Read and parse the server's initial handshake packet.

    Always sets ``protocol_version``, ``server_version``,
    ``server_thread_id``, ``salt`` (first scramble part) and the lower 16
    bits of ``server_capabilities``.

    When the packet carries the extended section, it also sets:

    - ``server_language``, ``server_charset`` and ``server_status``;
    - the upper capability bits;
    - the second scramble part, appended to ``salt``;
    - ``_auth_plugin_name``, if ``CLIENT.PLUGIN_AUTH`` is set.

    A short handshake without the extended section is now tolerated. It
    used to raise UnboundLocalError on ``salt_len``.
    """
    i = 0
    packet = self._read_packet()
    data = packet.get_all_data()

    self.protocol_version = data[i]
    i += 1

    # NUL-terminated server version string.
    server_end = data.find(b"\0", i)
    self.server_version = data[i:server_end].decode("latin1")
    i = server_end + 1

    self.server_thread_id = struct.unpack("<I", data[i : i + 4])
    i += 4

    self.salt = data[i : i + 8]
    i += 9  # 8 + 1(filler)

    # Lower 16 bits of the capability flags.
    self.server_capabilities = struct.unpack("<H", data[i : i + 2])[0]
    i += 2

    # Length of the second scramble part. It is only known from the extended
    # section, so default to 0 for handshakes that end here.
    salt_len = 0
    if len(data) >= i + 6:
        lang, stat, cap_h, salt_len = struct.unpack("<BHHB", data[i : i + 6])
        i += 6
        # TODO: deprecate server_language and server_charset.
        # mysqlclient-python doesn't provide it.
        self.server_language = lang
        try:
            self.server_charset = charset_by_id(lang).name
        except KeyError:
            # unknown collation
            self.server_charset = None

        self.server_status = stat
        if DEBUG:
            print("server_status: %x" % stat)

        # Upper 16 bits of the capability flags.
        self.server_capabilities |= cap_h << 16
        if DEBUG:
            print("salt_len:", salt_len)
        # The reported length includes part 1 (8 bytes) and the filler
        # already consumed; read what remains, but at least 12 bytes.
        salt_len = max(12, salt_len - 9)

    # reserved
    i += 10

    if len(data) >= i + salt_len:
        # salt_len includes auth_plugin_data_part_1 and filler
        self.salt += data[i : i + salt_len]
        i += salt_len

    i += 1
    # AUTH PLUGIN NAME may appear here.
    if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
        # Due to Bug#59453 the auth-plugin-name is missing the terminating
        # NUL-char in versions prior to 5.5.10 and 5.6.2.
        # ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
        # didn't use version checks as mariadb is corrected and reports
        # earlier than those two.
        server_end = data.find(b"\0", i)
        if server_end < 0:  # pragma: no cover - very specific upstream bug
            # not found \0 and last field so take it all
            self._auth_plugin_name = data[i:].decode("utf-8")
        else:
            self._auth_plugin_name = data[i:server_end].decode("utf-8")

    if _DEFAULT_AUTH_PLUGIN is not None:  # for tests
        self._auth_plugin_name = _DEFAULT_AUTH_PLUGIN
def get_server_info(self):
    """Return the server version string decoded from the handshake packet."""
    version_text = self.server_version
    return version_text
# DB-API 2.0 (PEP 249) optional extension: re-export the driver's exception
# classes from ``err`` as attributes of this (connection) class. Callers can
# then write e.g. ``except conn.OperationalError:``.
Warning = err.Warning
Error = err.Error
InterfaceError = err.InterfaceError
DatabaseError = err.DatabaseError
DataError = err.DataError
OperationalError = err.OperationalError
IntegrityError = err.IntegrityError
InternalError = err.InternalError
ProgrammingError = err.ProgrammingError
NotSupportedError = err.NotSupportedError
class MySQLResult:
    """Response to one command: an OK packet, a LOAD DATA LOCAL exchange, or a result set.

    Buffered mode: :meth:`read` consumes the whole response.

    Unbuffered mode: :meth:`init_unbuffered_query` reads the column
    descriptions. :meth:`_read_rowdata_packet_unbuffered` then yields one row
    per call.

    ``self.connection`` is reset to ``None`` once the response is consumed,
    releasing the reference to the connection.
    """

    def __init__(self, connection):
        """
        :type connection: Connection
        """
        self.connection = connection
        self.affected_rows = None
        self.insert_id = None
        self.server_status = None
        self.warning_count = 0
        self.message = None
        self.field_count = 0
        self.description = None
        self.rows = None
        self.has_next = None
        self.unbuffered_active = False

    def __del__(self):
        # Read the remainder of an unfinished unbuffered result (until EOF).
        if self.unbuffered_active:
            self._finish_unbuffered_query()

    def read(self):
        """Read the complete response to a command (buffered mode).

        ``self.connection`` is released afterwards, even if reading fails.
        """
        try:
            first_packet = self.connection._read_packet()

            if first_packet.is_ok_packet():
                self._read_ok_packet(first_packet)
            elif first_packet.is_load_local_packet():
                self._read_load_local_packet(first_packet)
            else:
                self._read_result_packet(first_packet)
        finally:
            self.connection = None

    def init_unbuffered_query(self):
        """
        :raise OperationalError: If the connection to the MySQL server is lost.
        :raise InternalError:
        """
        first_packet = self.connection._read_packet()

        if first_packet.is_ok_packet():
            self.connection = None
            self._read_ok_packet(first_packet)
        elif first_packet.is_load_local_packet():
            try:
                self._read_load_local_packet(first_packet)
            finally:
                self.connection = None
        else:
            self.field_count = first_packet.read_length_encoded_integer()
            self._get_descriptions()

            # Apparently, MySQLdb picks this number because it's the maximum
            # value of a 64bit unsigned integer. Since we're emulating MySQLdb,
            # we set it to this instead of None, which would be preferred.
            self.affected_rows = 18446744073709551615
            self.unbuffered_active = True

    def _read_ok_packet(self, first_packet):
        """Copy the status fields of an OK packet onto this result."""
        ok_packet = OKPacketWrapper(first_packet)
        self.affected_rows = ok_packet.affected_rows
        self.insert_id = ok_packet.insert_id
        self.server_status = ok_packet.server_status
        self.warning_count = ok_packet.warning_count
        self.message = ok_packet.message
        self.has_next = ok_packet.has_next

    def _read_load_local_packet(self, first_packet):
        """Upload the file requested by a LOAD DATA LOCAL packet, then read its OK.

        :raise RuntimeError: if the connection's ``_local_infile`` is false.
        :raise OperationalError: if the server's reply is not an OK packet.
        """
        if not self.connection._local_infile:
            raise RuntimeError(
                "**WARN**: Received LOAD_LOCAL packet but local_infile option is false."
            )
        load_packet = LoadLocalPacketWrapper(first_packet)
        sender = LoadLocalFile(load_packet.filename, self.connection)
        try:
            sender.send_data()
        except BaseException:
            # Consume the server's reply before propagating the failure.
            self.connection._read_packet()  # skip ok packet
            raise

        ok_packet = self.connection._read_packet()
        if (
            not ok_packet.is_ok_packet()
        ):  # pragma: no cover - upstream induced protocol error
            raise err.OperationalError(
                CR.CR_COMMANDS_OUT_OF_SYNC,
                "Commands Out of Sync",
            )
        self._read_ok_packet(ok_packet)

    def _check_packet_is_eof(self, packet):
        """Return True (and record warning_count/has_next) if ``packet`` is EOF."""
        if not packet.is_eof_packet():
            return False
        # TODO: Support CLIENT.DEPRECATE_EOF
        # 1) Add DEPRECATE_EOF to CAPABILITIES
        # 2) Mask CAPABILITIES with server_capabilities
        # 3) if server_capabilities & CLIENT.DEPRECATE_EOF:
        #    use OKPacketWrapper instead of EOFPacketWrapper
        wp = EOFPacketWrapper(packet)
        self.warning_count = wp.warning_count
        self.has_next = wp.has_next
        return True

    def _read_result_packet(self, first_packet):
        """Read column descriptions and all rows of a result set."""
        self.field_count = first_packet.read_length_encoded_integer()
        self._get_descriptions()
        self._read_rowdata_packet()

    def _read_rowdata_packet_unbuffered(self):
        """Return the next row, or None when inactive or the EOF is reached."""
        # Check if in an active query
        if not self.unbuffered_active:
            return

        # EOF
        packet = self.connection._read_packet()
        if self._check_packet_is_eof(packet):
            self.unbuffered_active = False
            self.connection = None
            self.rows = None
            return

        row = self._read_row_from_packet(packet)
        self.affected_rows = 1
        self.rows = (row,)  # rows should tuple of row for MySQL-python compatibility.
        return row

    def _finish_unbuffered_query(self):
        """Discard the remaining rows of an unbuffered result."""
        # After much reading on the MySQL protocol, it appears that there is,
        # in fact, no way to stop MySQL from sending all the data after
        # executing a query, so we just spin, and wait for an EOF packet.
        while self.unbuffered_active:
            try:
                packet = self.connection._read_packet()
            except err.OperationalError as e:
                if e.args[0] in (
                    ER.QUERY_TIMEOUT,
                    ER.STATEMENT_TIMEOUT,
                ):
                    # if the query timed out we can simply ignore this error
                    self.unbuffered_active = False
                    self.connection = None
                    return

                raise

            if self._check_packet_is_eof(packet):
                self.unbuffered_active = False
                self.connection = None  # release reference to kill cyclic reference.

    def _read_rowdata_packet(self):
        """Read a rowdata packet for each data row in the result set."""
        rows = []
        while True:
            packet = self.connection._read_packet()
            if self._check_packet_is_eof(packet):
                self.connection = None  # release reference to kill cyclic reference.
                break
            rows.append(self._read_row_from_packet(packet))

        self.affected_rows = len(rows)
        self.rows = tuple(rows)

    def _read_row_from_packet(self, packet):
        """Decode one row using the per-column (encoding, converter) pairs.

        NULL columns stay ``None``. A packet with fewer columns than expected
        yields a shorter tuple.
        """
        row = []
        for encoding, converter in self.converters:
            try:
                data = packet.read_length_coded_string()
            except IndexError:
                # No more columns in this row
                # See https://github.com/PyMySQL/PyMySQL/pull/434
                break
            if data is not None:
                if encoding is not None:
                    data = data.decode(encoding)
                if DEBUG:
                    print("DEBUG: DATA = ", data)
                if converter is not None:
                    data = converter(data)
            row.append(data)
        return tuple(row)

    def _get_descriptions(self):
        """Read a column descriptor packet for each column in the result.

        :raise OperationalError: ``CR_COMMANDS_OUT_OF_SYNC`` if the descriptors
            are not followed by an EOF packet.
        """
        self.fields = []
        self.converters = []
        use_unicode = self.connection.use_unicode
        conn_encoding = self.connection.encoding
        description = []

        for _ in range(self.field_count):
            field = self.connection._read_packet(FieldDescriptorPacket)
            self.fields.append(field)
            description.append(field.description())
            field_type = field.type_code
            if use_unicode:
                if field_type == FIELD_TYPE.JSON:
                    # When SELECT from JSON column: charset = binary
                    # When SELECT CAST(... AS JSON): charset = connection encoding
                    # This behavior is different from TEXT / BLOB.
                    # We should decode result by connection encoding regardless charsetnr.
                    # See https://github.com/PyMySQL/PyMySQL/issues/488
                    encoding = conn_encoding  # SELECT CAST(... AS JSON)
                elif field_type in TEXT_TYPES:
                    if field.charsetnr == 63:  # binary
                        # TEXTs with charset=binary means BINARY types.
                        encoding = None
                    else:
                        encoding = conn_encoding
                else:
                    # Integers, Dates and Times, and other basic data is encoded in ascii
                    encoding = "ascii"
            else:
                encoding = None
            converter = self.connection.decoders.get(field_type)
            if converter is converters.through:
                converter = None
            if DEBUG:
                print(f"DEBUG: field={field}, converter={converter}")
            self.converters.append((encoding, converter))

        eof_packet = self.connection._read_packet()
        # Explicit check instead of ``assert`` so it is not stripped by ``-O``.
        if not eof_packet.is_eof_packet():
            raise err.OperationalError(
                CR.CR_COMMANDS_OUT_OF_SYNC,
                "Protocol error, expecting EOF",
            )
        self.description = tuple(description)
class LoadLocalFile:
    """Upload a client-side file in answer to a LOAD DATA LOCAL request."""

    def __init__(self, filename, connection):
        # filename: the file name carried by the server's LOAD_LOCAL packet.
        self.filename = filename
        self.connection = connection

    def send_data(self):
        """Send data packets from the local file to the server

        The file is read in chunks of ``min(max_allowed_packet, 16 KiB)``
        bytes, each written as one packet. If the connection is not closed,
        an empty packet always follows to mark the end of the data, even when
        the file could not be read.

        :raise InterfaceError: if the connection has no socket.
        :raise OperationalError: ``ER.FILE_NOT_FOUND`` if opening or reading
            the file raises ``OSError``. The original error is chained as
            ``__cause__`` so the real reason (e.g. permission denied) is kept.
        """
        if not self.connection._sock:
            raise err.InterfaceError(0, "")
        conn: Connection = self.connection

        try:
            with open(self.filename, "rb") as open_file:
                packet_size = min(
                    conn.max_allowed_packet, 16 * 1024
                )  # 16KB is efficient enough
                while True:
                    chunk = open_file.read(packet_size)
                    if not chunk:
                        break
                    conn.write_packet(chunk)
        except OSError as e:
            raise err.OperationalError(
                ER.FILE_NOT_FOUND,
                f"Can't find file '{self.filename}'",
            ) from e
        finally:
            if not conn._closed:
                # send the empty packet to signify we are done sending data
                conn.write_packet(b"")