Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/pymysql/connections.py: 27%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

856 statements  

1# Python implementation of the MySQL client-server protocol 

2# http://dev.mysql.com/doc/internals/en/client-server-protocol.html 

3# Error codes: 

4# https://dev.mysql.com/doc/refman/5.5/en/error-handling.html 

5import errno 

6import os 

7import socket 

8import struct 

9import sys 

10import traceback 

11import warnings 

12 

13from . import _auth 

14 

15from .charset import charset_by_name, charset_by_id 

16from .constants import CLIENT, COMMAND, CR, ER, FIELD_TYPE, SERVER_STATUS 

17from . import converters 

18from .cursors import Cursor 

19from .optionfile import Parser 

20from .protocol import ( 

21 dump_packet, 

22 MysqlPacket, 

23 FieldDescriptorPacket, 

24 OKPacketWrapper, 

25 EOFPacketWrapper, 

26 LoadLocalPacketWrapper, 

27) 

28from . import err, VERSION_STRING 

29 

# The ``ssl`` module is optional (some minimal Python builds omit it).  When it
# is missing, TLS support is reported as unavailable instead of failing import.
try:
    import ssl
except ImportError:
    ssl = None
    SSL_ENABLED = False
else:
    SSL_ENABLED = True

37 

# Login name used when the caller passes no ``user``: the OS account running
# this process.  ``getpass`` is removed from the module namespace afterwards
# (only on the success path, since the ``del`` follows getuser()).
try:
    import getpass

    DEFAULT_USER = getpass.getuser()
    del getpass
except (ImportError, KeyError, OSError):
    # When there's no entry in OS database for a current user:
    # KeyError is raised in Python 3.12 and below.
    # OSError is raised in Python 3.13+
    DEFAULT_USER = None

48 

DEBUG = False  # when True, packets and connection steps are printed to stdout
_DEFAULT_AUTH_PLUGIN = None  # if this is not None, use it instead of server's default.

# String, blob, bit and geometry column types grouped together.
# NOTE(review): consumed outside this excerpt — presumably to decide which
# column values are treated as text/bytes when decoding; confirm at call site.
TEXT_TYPES = {
    FIELD_TYPE.STRING,
    FIELD_TYPE.VAR_STRING,
    FIELD_TYPE.VARCHAR,
    FIELD_TYPE.TINY_BLOB,
    FIELD_TYPE.BLOB,
    FIELD_TYPE.MEDIUM_BLOB,
    FIELD_TYPE.LONG_BLOB,
    FIELD_TYPE.BIT,
    FIELD_TYPE.GEOMETRY,
}

63 

64 

DEFAULT_CHARSET = "utf8mb4"  # used when the caller supplies no charset

# Largest payload one protocol packet can carry: the header's length field is
# 3 bytes wide, so anything this size or larger is split across packets.
MAX_PACKET_LEN = (1 << 24) - 1

68 

69 

70def _pack_int24(n): 

71 return struct.pack("<I", n)[:3] 

72 

73 

74# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger 

75def _lenenc_int(i): 

76 if i < 0: 

77 raise ValueError( 

78 "Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i 

79 ) 

80 elif i < 0xFB: 

81 return bytes([i]) 

82 elif i < (1 << 16): 

83 return b"\xfc" + struct.pack("<H", i) 

84 elif i < (1 << 24): 

85 return b"\xfd" + struct.pack("<I", i)[:3] 

86 elif i < (1 << 64): 

87 return b"\xfe" + struct.pack("<Q", i) 

88 else: 

89 raise ValueError( 

90 f"Encoding {i:x} is larger than {1 << 64:x} - no representation in LengthEncodedInteger" 

91 ) 

92 

93 

94class Connection: 

95 """ 

96 Representation of a socket with a mysql server. 

97 

98 The proper way to get an instance of this class is to call 

99 connect(). 

100 

101 Establish a connection to the MySQL database. Accepts several 

102 arguments: 

103 

104 :param host: Host where the database server is located. 

105 :param user: Username to log in as. 

106 :param password: Password to use. 

107 :param database: Database to use, None to not use a particular one. 

108 :param port: MySQL port to use, default is usually OK. (default: 3306) 

109 :param bind_address: When the client has multiple network interfaces, specify 

110 the interface from which to connect to the host. Argument can be 

111 a hostname or an IP address. 

112 :param unix_socket: Use a unix socket rather than TCP/IP. 

113 :param read_timeout: The timeout for reading from the connection in seconds. 

114 (default: None - no timeout) 

115 :param write_timeout: The timeout for writing to the connection in seconds. 

116 (default: None - no timeout) 

117 :param str charset: Charset to use. 

118 :param str collation: Collation name to use. 

119 :param sql_mode: Default SQL_MODE to use. 

120 :param read_default_file: 

121 Specifies my.cnf file to read these parameters from under the [client] section. 

122 :param conv: 

123 Conversion dictionary to use instead of the default one. 

124 This is used to provide custom marshalling and unmarshalling of types. 

125 See converters. 

126 :param use_unicode: 

127 Whether or not to default to unicode strings. 

128 This option defaults to true. 

129 :param client_flag: Custom flags to send to MySQL. Find potential values in constants.CLIENT. 

130 :param cursorclass: Custom cursor class to use. 

131 :param init_command: Initial SQL statement to run when connection is established. 

132 :param connect_timeout: The timeout for connecting to the database in seconds. 

133 (default: 10, min: 1, max: 31536000) 

134 :param ssl: An ssl.SSLContext, or a dict of arguments similar to mysql_ssl_set()'s parameters. 

135 Passing a dict is deprecated; use the individual ``ssl_*`` parameters or an 

136 ``ssl.SSLContext`` instead. 

137 :param ssl_ca: Path to the file that contains a PEM-formatted CA certificate. 

138 :param ssl_cert: Path to the file that contains a PEM-formatted client certificate. 

139 :param ssl_disabled: A boolean value that disables usage of TLS. Unlike other SSL options, 

140 setting this to True explicitly prohibits the use of TLS, even if the server supports it. 

141 :param ssl_key: Path to the file that contains a PEM-formatted private key for 

142 the client certificate. 

143 :param ssl_key_password: The password for the client certificate private key. 

144 :param ssl_verify_cert: Set to true to check the server certificate's validity. 

145 :param ssl_verify_identity: Set to true to check the server's identity. 

146 :param read_default_group: Group to read from in the configuration file. 

147 :param autocommit: Autocommit mode. None means use server default. (default: False) 

148 :param local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False) 

149 :param max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB) 

150 Only used to limit size of "LOAD LOCAL INFILE" data packet smaller than default (16KB). 

151 :param defer_connect: Don't explicitly connect on construction - wait for connect call. 

152 (default: False) 

153 :param auth_plugin_map: A dict of plugin names to a class that processes that plugin. 

154 The class will take the Connection object as the argument to the constructor. 

155 The class needs an authenticate method taking an authentication packet as 

156 an argument. For the dialog plugin, a prompt(echo, prompt) method can be used 

157 (if no authenticate method) for returning a string from the user. (experimental) 

158 :param server_public_key: SHA256 authentication plugin public key value. (default: None) 

159 :param binary_prefix: Add _binary prefix on bytes and bytearray. (default: False) 

160 :param compress: Not supported. 

161 :param named_pipe: Not supported. 

162 :param db: **DEPRECATED** Alias for database. 

163 :param passwd: **DEPRECATED** Alias for password. 

164 

165 See `Connection <https://www.python.org/dev/peps/pep-0249/#connection-objects>`_ in the 

166 specification. 

167 """ 

168 

169 _sock = None 

170 _rfile = None 

171 _auth_plugin_name = "" 

172 _closed = False 

173 _secure = False 

174 

    def __init__(
        self,
        *,
        user=None,  # The first four arguments are based on DB-API 2.0 recommendation.
        password="",
        host=None,
        database=None,
        unix_socket=None,
        port=0,
        charset="",
        collation=None,
        sql_mode=None,
        read_default_file=None,
        conv=None,
        use_unicode=True,
        client_flag=0,
        cursorclass=Cursor,
        init_command=None,
        connect_timeout=10,
        read_default_group=None,
        autocommit=False,
        local_infile=False,
        max_allowed_packet=16 * 1024 * 1024,
        defer_connect=False,
        auth_plugin_map=None,
        read_timeout=None,
        write_timeout=None,
        bind_address=None,
        binary_prefix=False,
        program_name=None,
        server_public_key=None,
        ssl=None,
        ssl_ca=None,
        ssl_cert=None,
        ssl_disabled=None,
        ssl_key=None,
        ssl_key_password=None,
        ssl_verify_cert=None,
        ssl_verify_identity=None,
        compress=None,  # not supported
        named_pipe=None,  # not supported
        passwd=None,  # deprecated
        db=None,  # deprecated
    ):
        """Validate and store connection settings, then connect unless deferred.

        Parameters are documented on the class.  Settings not given explicitly
        may be filled in from an option file (``read_default_file`` /
        ``read_default_group``).  Unless ``defer_connect`` is true, connect()
        is called before returning.

        :raise NotImplementedError: For ``compress``/``named_pipe``, or when
            TLS options are given but the ``ssl`` module is unavailable.
        :raise ValueError: For an invalid port type or out-of-range timeouts.
        """
        # NOTE: inside this method the ``ssl`` parameter shadows the ``ssl`` module.

        # Deprecated aliases are only honoured when the new name was not given.
        if db is not None and database is None:
            warnings.warn("'db' is deprecated, use 'database'", DeprecationWarning, 3)
            database = db
        if passwd is not None and not password:
            warnings.warn(
                "'passwd' is deprecated, use 'password'", DeprecationWarning, 3
            )
            password = passwd

        if compress or named_pipe:
            raise NotImplementedError(
                "compress and named_pipe arguments are not supported"
            )

        self._local_infile = bool(local_infile)
        if self._local_infile:
            client_flag |= CLIENT.LOCAL_FILES

        # Naming a group without a file implies the platform's default option file.
        if read_default_group and not read_default_file:
            if sys.platform.startswith("win"):
                read_default_file = "c:\\my.ini"
            else:
                read_default_file = "/etc/my.cnf"

        if read_default_file:
            if not read_default_group:
                read_default_group = "client"

            cfg = Parser()
            cfg.read(os.path.expanduser(read_default_file))

            # A truthy explicit argument wins; otherwise use the option-file
            # value, falling back to the original argument if the lookup fails.
            def _config(key, arg):
                if arg:
                    return arg
                try:
                    return cfg.get(read_default_group, key)
                except Exception:
                    return arg

            user = _config("user", user)
            password = _config("password", password)
            host = _config("host", host)
            database = _config("database", database)
            unix_socket = _config("socket", unix_socket)
            port = int(_config("port", port))
            bind_address = _config("bind-address", bind_address)
            charset = _config("default-character-set", charset)
            if not ssl:
                ssl = {}
            if isinstance(ssl, dict):
                # NOTE(review): this writes into a caller-supplied ``ssl`` dict in place.
                for key in ["ca", "capath", "cert", "key", "password", "cipher"]:
                    value = _config("ssl-" + key, ssl.get(key))
                    if value:
                        ssl[key] = value

        # TLS mode selection:
        # - ssl_disabled truthy: never use TLS.
        # - explicit options (any ssl_* argument, or a non-empty ``ssl``): REQUIRED.
        # - otherwise, when the ssl module is available: PREFERRED.
        self.ssl = False
        self._ssl_required = False
        if not ssl_disabled:
            if ssl_ca or ssl_cert or ssl_key or ssl_verify_cert or ssl_verify_identity:
                # Individual ssl_* arguments replace whatever ``ssl`` held.
                ssl = {
                    "ca": ssl_ca,
                    "check_hostname": bool(ssl_verify_identity),
                    "verify_mode": ssl_verify_cert
                    if ssl_verify_cert is not None
                    else False,
                }
                if ssl_cert is not None:
                    ssl["cert"] = ssl_cert
                if ssl_key is not None:
                    ssl["key"] = ssl_key
                if ssl_key_password is not None:
                    ssl["password"] = ssl_key_password
            if ssl:
                if not SSL_ENABLED:
                    raise NotImplementedError("ssl module not found")
                self.ssl = True
                self._ssl_required = True
                client_flag |= CLIENT.SSL
                self.ctx = self._create_ssl_ctx(ssl)
            elif SSL_ENABLED:
                # No explicit SSL options specified: use PREFERRED mode.
                # Attempt SSL but fall back gracefully if the server doesn't support it.
                self.ssl = True
                self._ssl_required = False
                self.ctx = self._create_ssl_ctx({})

        self.host = host or "localhost"
        self.port = port or 3306
        if type(self.port) is not int:
            raise ValueError("port should be of type int")
        self.user = user or DEFAULT_USER
        self.password = password or b""
        # str passwords are sent as latin-1 bytes, independent of the connection charset.
        if isinstance(self.password, str):
            self.password = self.password.encode("latin1")
        self.db = database
        self.unix_socket = unix_socket
        self.bind_address = bind_address
        if not (0 < connect_timeout <= 31536000):
            raise ValueError("connect_timeout should be >0 and <=31536000")
        # The range check above already excludes 0, so ``or None`` is only defensive.
        self.connect_timeout = connect_timeout or None
        if read_timeout is not None and read_timeout <= 0:
            raise ValueError("read_timeout should be > 0")
        self._read_timeout = read_timeout
        if write_timeout is not None and write_timeout <= 0:
            raise ValueError("write_timeout should be > 0")
        self._write_timeout = write_timeout

        self.charset = charset or DEFAULT_CHARSET
        self.collation = collation
        self.use_unicode = use_unicode

        # Python codec name matching the MySQL charset.
        self.encoding = charset_by_name(self.charset).encoding

        client_flag |= CLIENT.CAPABILITIES
        if self.db:
            client_flag |= CLIENT.CONNECT_WITH_DB

        self.client_flag = client_flag

        self.cursorclass = cursorclass

        self._result = None
        self._affected_rows = 0
        self.host_info = "Not connected"

        # specified autocommit mode. None means use server default.
        self.autocommit_mode = autocommit

        if conv is None:
            conv = converters.conversions

        # Need for MySQLdb compatibility.
        # ``conv`` mixes both directions: int keys are decoders (presumably
        # FIELD_TYPE codes), every other key (e.g. a Python type) is an encoder.
        self.encoders = {k: v for (k, v) in conv.items() if type(k) is not int}
        self.decoders = {k: v for (k, v) in conv.items() if type(k) is int}
        self.sql_mode = sql_mode
        self.init_command = init_command
        self.max_allowed_packet = max_allowed_packet
        self._auth_plugin_map = auth_plugin_map or {}
        self._binary_prefix = binary_prefix
        self.server_public_key = server_public_key

        self._connect_attrs = {
            "_client_name": "pymysql",
            "_client_version": VERSION_STRING,
            "_pid": str(os.getpid()),
        }

        if program_name:
            self._connect_attrs["program_name"] = program_name

        if defer_connect:
            self._sock = None
        else:
            self.connect()

373 

    def __enter__(self):
        """Support ``with connection:``; the connection itself is the bound value."""
        return self

376 

377 def __exit__(self, *exc_info): 

378 del exc_info 

379 self.close() 

380 

    def _create_ssl_ctx(self, sslp):
        """Build the ``ssl.SSLContext`` used to upgrade the connection to TLS.

        :param sslp: Either a ready ``ssl.SSLContext`` (returned unchanged) or a
            dict with optional keys ``ca``, ``capath``, ``cert``, ``key``,
            ``password``, ``cipher``, ``check_hostname`` and ``verify_mode``.
        :return: The configured context.

        Without a CA (neither ``ca`` nor ``capath``), hostname checking is off
        and, unless ``verify_mode`` says otherwise, the server certificate is
        not verified.
        """
        if isinstance(sslp, ssl.SSLContext):
            return sslp
        ca = sslp.get("ca")
        capath = sslp.get("capath")
        hasnoca = ca is None and capath is None
        ctx = ssl.create_default_context(cafile=ca, capath=capath)

        # Python 3.13 enables VERIFY_X509_STRICT by default.
        # But self signed certificates that are generated by MySQL automatically
        # doesn't pass the verification.
        ctx.verify_flags &= ~ssl.VERIFY_X509_STRICT

        # check_hostname is assigned before verify_mode.  NOTE(review): per the
        # ssl docs, setting CERT_NONE while check_hostname is True raises
        # ValueError — reachable here with a CA, check_hostname True and a
        # verify_mode that resolves to CERT_NONE.
        ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True)
        verify_mode_value = sslp.get("verify_mode")
        if verify_mode_value is None:
            ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
        elif isinstance(verify_mode_value, bool):
            ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE
        else:
            # Textual spellings are accepted case-insensitively.
            if isinstance(verify_mode_value, str):
                verify_mode_value = verify_mode_value.lower()
            if verify_mode_value in ("none", "0", "false", "no"):
                ctx.verify_mode = ssl.CERT_NONE
            elif verify_mode_value == "optional":
                ctx.verify_mode = ssl.CERT_OPTIONAL
            elif verify_mode_value in ("required", "1", "true", "yes"):
                ctx.verify_mode = ssl.CERT_REQUIRED
            else:
                # Unrecognised value: same default as when verify_mode is absent.
                ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
        if "cert" in sslp:
            ctx.load_cert_chain(
                sslp["cert"], keyfile=sslp.get("key"), password=sslp.get("password")
            )
        if "cipher" in sslp:
            ctx.set_ciphers(sslp["cipher"])
        ctx.options |= ssl.OP_NO_SSLv2
        ctx.options |= ssl.OP_NO_SSLv3
        return ctx

420 

421 def close(self): 

422 """ 

423 Send the quit message and close the socket. 

424 

425 See `Connection.close() <https://www.python.org/dev/peps/pep-0249/#Connection.close>`_ 

426 in the specification. 

427 

428 :raise Error: If the connection is already closed. 

429 """ 

430 if self._closed: 

431 raise err.Error("Already closed") 

432 self._closed = True 

433 if self._sock is None: 

434 return 

435 send_data = struct.pack("<iB", 1, COMMAND.COM_QUIT) 

436 try: 

437 self._write_bytes(send_data) 

438 except Exception: 

439 pass 

440 finally: 

441 self._force_close() 

442 

443 @property 

444 def open(self): 

445 """Return True if the connection is open.""" 

446 return self._sock is not None 

447 

448 def _force_close(self): 

449 """Close connection without QUIT message.""" 

450 if self._rfile: 

451 self._rfile.close() 

452 if self._sock: 

453 try: 

454 self._sock.close() 

455 except: # noqa 

456 pass 

457 self._sock = None 

458 self._rfile = None 

459 

    # On garbage collection the socket is closed without sending COM_QUIT.
    __del__ = _force_close

461 

462 def autocommit(self, value): 

463 self.autocommit_mode = bool(value) 

464 current = self.get_autocommit() 

465 if value != current: 

466 self._send_autocommit_mode() 

467 

468 def get_autocommit(self): 

469 return bool(self.server_status & SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT) 

470 

471 def _read_ok_packet(self): 

472 pkt = self._read_packet() 

473 if not pkt.is_ok_packet(): 

474 raise err.OperationalError( 

475 CR.CR_COMMANDS_OUT_OF_SYNC, 

476 "Command Out of Sync", 

477 ) 

478 ok = OKPacketWrapper(pkt) 

479 self.server_status = ok.server_status 

480 return ok 

481 

482 def _send_autocommit_mode(self): 

483 """Set whether or not to commit after every execute().""" 

484 self._execute_command( 

485 COMMAND.COM_QUERY, "SET AUTOCOMMIT = %s" % self.escape(self.autocommit_mode) 

486 ) 

487 self._read_ok_packet() 

488 

489 def begin(self): 

490 """Begin transaction.""" 

491 self._execute_command(COMMAND.COM_QUERY, "BEGIN") 

492 self._read_ok_packet() 

493 

494 def commit(self): 

495 """ 

496 Commit changes to stable storage. 

497 

498 See `Connection.commit() <https://www.python.org/dev/peps/pep-0249/#commit>`_ 

499 in the specification. 

500 """ 

501 self._execute_command(COMMAND.COM_QUERY, "COMMIT") 

502 self._read_ok_packet() 

503 

504 def rollback(self): 

505 """ 

506 Roll back the current transaction. 

507 

508 See `Connection.rollback() <https://www.python.org/dev/peps/pep-0249/#rollback>`_ 

509 in the specification. 

510 """ 

511 self._execute_command(COMMAND.COM_QUERY, "ROLLBACK") 

512 self._read_ok_packet() 

513 

514 def show_warnings(self): 

515 """Send the "SHOW WARNINGS" SQL command.""" 

516 self._execute_command(COMMAND.COM_QUERY, "SHOW WARNINGS") 

517 result = MySQLResult(self) 

518 result.read() 

519 return result.rows 

520 

521 def select_db(self, db): 

522 """ 

523 Set current db. 

524 

525 :param db: The name of the db. 

526 """ 

527 self._execute_command(COMMAND.COM_INIT_DB, db) 

528 self._read_ok_packet() 

529 

530 def escape(self, obj, mapping=None): 

531 """Escape whatever value is passed. 

532 

533 Non-standard, for internal use; do not use this in your applications. 

534 """ 

535 if isinstance(obj, str): 

536 return "'" + self.escape_string(obj) + "'" 

537 if isinstance(obj, (bytes, bytearray)): 

538 ret = self._quote_bytes(obj) 

539 if self._binary_prefix: 

540 ret = "_binary" + ret 

541 return ret 

542 return converters.escape_item(obj, self.charset, mapping=mapping) 

543 

544 def literal(self, obj): 

545 """Alias for escape(). 

546 

547 Non-standard, for internal use; do not use this in your applications. 

548 """ 

549 return self.escape(obj, self.encoders) 

550 

551 def escape_string(self, s): 

552 if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES: 

553 return s.replace("'", "''") 

554 return converters.escape_string(s) 

555 

556 def _quote_bytes(self, s): 

557 if self.server_status & SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES: 

558 return "'{}'".format( 

559 s.replace(b"'", b"''").decode("ascii", "surrogateescape") 

560 ) 

561 return converters.escape_bytes(s) 

562 

563 def cursor(self, cursor=None): 

564 """ 

565 Create a new cursor to execute queries with. 

566 

567 :param cursor: The type of cursor to create. None means use Cursor. 

568 :type cursor: :py:class:`Cursor`, :py:class:`SSCursor`, :py:class:`DictCursor`, 

569 or :py:class:`SSDictCursor`. 

570 """ 

571 if cursor: 

572 return cursor(self) 

573 return self.cursorclass(self) 

574 

575 # The following methods are INTERNAL USE ONLY (called from Cursor) 

576 def query(self, sql, unbuffered=False): 

577 # if DEBUG: 

578 # print("DEBUG: sending query:", sql) 

579 if isinstance(sql, str): 

580 sql = sql.encode(self.encoding, "surrogateescape") 

581 self._execute_command(COMMAND.COM_QUERY, sql) 

582 self._affected_rows = self._read_query_result(unbuffered=unbuffered) 

583 return self._affected_rows 

584 

585 def next_result(self, unbuffered=False): 

586 self._affected_rows = self._read_query_result(unbuffered=unbuffered) 

587 return self._affected_rows 

588 

    def affected_rows(self):
        """Return the affected-row count stored by the last query()/next_result() (0 initially)."""
        return self._affected_rows

591 

592 def kill(self, thread_id): 

593 if not isinstance(thread_id, int): 

594 raise TypeError("thread_id must be an integer") 

595 self.query(f"KILL {thread_id:d}") 

596 

    def ping(self, reconnect=False):
        """
        Check if the server is alive.

        `reconnect` is deprecated. Create a new connection if you want to reconnect.

        Sends COM_PING and expects an OK packet.  With ``reconnect`` true, a
        closed connection is reopened first, or a failed ping is retried once
        on a fresh connection.

        :param reconnect: If the connection is closed, reconnect.
        :type reconnect: boolean

        :raise Error: If the connection is closed and reconnect=False.
        """
        # emit deprecation warning for reconnect.
        if reconnect:
            warnings.warn(
                "The 'reconnect' argument is deprecated. Create a new connection if you want to reconnect.",
                DeprecationWarning,
                2,
            )
        if self._sock is None:
            if reconnect:
                self.connect()
                # Freshly connected: a failing ping below must not reconnect again.
                reconnect = False
            else:
                raise err.Error("Already closed")
        try:
            self._execute_command(COMMAND.COM_PING, "")
            self._read_ok_packet()
        except Exception:
            if reconnect:
                # Single retry on a new connection; ping(False) will not retry.
                self.connect()
                self.ping(False)
            else:
                raise

630 

631 def set_charset(self, charset): 

632 """Deprecated. Use set_character_set() instead.""" 

633 warnings.warn( 

634 "'set_charset' is deprecated, use 'set_character_set' instead", 

635 DeprecationWarning, 

636 2, 

637 ) 

638 # This function has been implemented in old PyMySQL. 

639 # But this name is different from MySQLdb. 

640 # So we keep this function for compatibility and add 

641 # new set_character_set() function. 

642 self.set_character_set(charset) 

643 

644 def set_character_set(self, charset, collation=None): 

645 """ 

646 Set charaset (and collation) 

647 

648 Send "SET NAMES charset [COLLATE collation]" query. 

649 Update Connection.encoding based on charset. 

650 """ 

651 # Make sure charset is supported. 

652 encoding = charset_by_name(charset).encoding 

653 

654 if collation: 

655 query = f"SET NAMES {charset} COLLATE {collation}" 

656 else: 

657 query = f"SET NAMES {charset}" 

658 self._execute_command(COMMAND.COM_QUERY, query) 

659 self._read_packet() 

660 self.charset = charset 

661 self.encoding = encoding 

662 self.collation = collation 

663 

664 def connect(self, sock=None): 

665 self._closed = False 

666 try: 

667 if sock is None: 

668 if self.unix_socket: 

669 sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) 

670 sock.settimeout(self.connect_timeout) 

671 sock.connect(self.unix_socket) 

672 self.host_info = "Localhost via UNIX socket" 

673 self._secure = True 

674 if DEBUG: 

675 print("connected using unix_socket") 

676 else: 

677 kwargs = {} 

678 if self.bind_address is not None: 

679 kwargs["source_address"] = (self.bind_address, 0) 

680 while True: 

681 try: 

682 sock = socket.create_connection( 

683 (self.host, self.port), self.connect_timeout, **kwargs 

684 ) 

685 break 

686 except OSError as e: 

687 if e.errno == errno.EINTR: 

688 continue 

689 raise 

690 self.host_info = "socket %s:%d" % (self.host, self.port) 

691 if DEBUG: 

692 print("connected using socket") 

693 sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) 

694 sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) 

695 sock.settimeout(None) 

696 

697 self._sock = sock 

698 self._rfile = sock.makefile("rb") 

699 self._next_seq_id = 0 

700 

701 self._get_server_information() 

702 self._request_authentication() 

703 

704 # Send "SET NAMES" query on init for: 

705 # - Ensure charaset (and collation) is set to the server. 

706 # - collation_id in handshake packet may be ignored. 

707 # - If collation is not specified, we don't know what is server's 

708 # default collation for the charset. For example, default collation 

709 # of utf8mb4 is: 

710 # - MySQL 5.7, MariaDB 10.x: utf8mb4_general_ci 

711 # - MySQL 8.0: utf8mb4_0900_ai_ci 

712 # 

713 # Reference: 

714 # - https://github.com/PyMySQL/PyMySQL/issues/1092 

715 # - https://github.com/wagtail/wagtail/issues/9477 

716 # - https://zenn.dev/methane/articles/2023-mysql-collation (Japanese) 

717 self.set_character_set(self.charset, self.collation) 

718 

719 if self.sql_mode is not None: 

720 c = self.cursor() 

721 c.execute("SET sql_mode=%s", (self.sql_mode,)) 

722 c.close() 

723 

724 if self.init_command is not None: 

725 c = self.cursor() 

726 c.execute(self.init_command) 

727 c.close() 

728 

729 if self.autocommit_mode is not None: 

730 self.autocommit(self.autocommit_mode) 

731 except BaseException as e: 

732 self._force_close() 

733 

734 if isinstance(e, (OSError, IOError)): 

735 exc = err.OperationalError( 

736 CR.CR_CONN_HOST_ERROR, 

737 f"Can't connect to MySQL server on {self.host!r} ({e})", 

738 ) 

739 # Keep original exception and traceback to investigate error. 

740 exc.original_exception = e 

741 exc.traceback = traceback.format_exc() 

742 if DEBUG: 

743 print(exc.traceback) 

744 raise exc 

745 

746 # If e is neither DatabaseError or IOError, It's a bug. 

747 # But raising AssertionError hides original error. 

748 # So just reraise it. 

749 raise 

750 

751 def write_packet(self, payload): 

752 """Writes an entire "mysql packet" in its entirety to the network 

753 adding its length and sequence number. 

754 """ 

755 # Internal note: when you build packet manually and calls _write_bytes() 

756 # directly, you should set self._next_seq_id properly. 

757 data = _pack_int24(len(payload)) + bytes([self._next_seq_id]) + payload 

758 if DEBUG: 

759 dump_packet(data) 

760 self._write_bytes(data) 

761 self._next_seq_id = (self._next_seq_id + 1) % 256 

762 

    def _read_packet(self, packet_type=MysqlPacket):
        """Read an entire "mysql packet" in its entirety from the network
        and return a MysqlPacket type that represents the results.

        A physical packet carrying exactly MAX_PACKET_LEN bytes is continued
        by the next one; all pieces are concatenated before being wrapped in
        *packet_type*.  Error packets go through ``raise_for_error()``
        (presumably raising the mapped MySQL error) instead of being returned.

        :raise OperationalError: If the connection to the MySQL server is lost.
        :raise InternalError: If the packet sequence number is wrong.
        """
        buff = bytearray()
        while True:
            # 4-byte header: 3-byte little-endian payload length + 1-byte sequence id.
            packet_header = self._read_bytes(4)
            # if DEBUG: dump_packet(packet_header)

            btrl, btrh, packet_number = struct.unpack("<HBB", packet_header)
            bytes_to_read = btrl + (btrh << 16)
            if packet_number != self._next_seq_id:
                # The stream is out of sync, so the connection is unusable.
                self._force_close()
                if packet_number == 0:
                    # MariaDB sends error packet with seqno==0 when shutdown
                    raise err.OperationalError(
                        CR.CR_SERVER_LOST,
                        "Lost connection to MySQL server during query",
                    )
                raise err.InternalError(
                    "Packet sequence number wrong - got %d expected %d"
                    % (packet_number, self._next_seq_id)
                )
            self._next_seq_id = (self._next_seq_id + 1) % 256

            recv_data = self._read_bytes(bytes_to_read)
            if DEBUG:
                dump_packet(recv_data)
            buff += recv_data
            # https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
            # A full-size chunk means more follows; a shorter one ends the packet.
            if bytes_to_read < MAX_PACKET_LEN:
                break

        packet = packet_type(bytes(buff), self.encoding)
        if packet.is_error_packet():
            # An error terminates any in-progress unbuffered result.
            if self._result is not None and self._result.unbuffered_active is True:
                self._result.unbuffered_active = False
            packet.raise_for_error()
        return packet

805 

    def _read_bytes(self, num_bytes):
        """Read exactly *num_bytes* from the connection.

        Applies the configured read timeout and retries reads interrupted by
        a signal (EINTR).  Every failure path force-closes the connection.

        :raise OperationalError: CR_SERVER_LOST on a socket error, or when
            fewer than *num_bytes* arrive.
        """
        self._sock.settimeout(self._read_timeout)
        while True:
            try:
                data = self._rfile.read(num_bytes)
                break
            except OSError as e:
                if e.errno == errno.EINTR:
                    continue
                self._force_close()
                raise err.OperationalError(
                    CR.CR_SERVER_LOST,
                    f"Lost connection to MySQL server during query ({e})",
                )
            except BaseException:
                # Don't convert unknown exception to MySQLError.
                self._force_close()
                raise
        # A short read means end of stream: the server closed the connection.
        if len(data) < num_bytes:
            self._force_close()
            raise err.OperationalError(
                CR.CR_SERVER_LOST, "Lost connection to MySQL server during query"
            )
        return data

830 

831 def _write_bytes(self, data): 

832 self._sock.settimeout(self._write_timeout) 

833 try: 

834 self._sock.sendall(data) 

835 except OSError as e: 

836 self._force_close() 

837 raise err.OperationalError( 

838 CR.CR_SERVER_GONE_ERROR, f"MySQL server has gone away ({e!r})" 

839 ) 

840 

841 def _read_query_result(self, unbuffered=False): 

842 self._result = None 

843 result = MySQLResult(self) 

844 if unbuffered: 

845 result.init_unbuffered_query() 

846 else: 

847 result.read() 

848 self._result = result 

849 if result.server_status is not None: 

850 self.server_status = result.server_status 

851 return result.affected_rows 

852 

853 def insert_id(self): 

854 if self._result: 

855 return self._result.insert_id 

856 else: 

857 return 0 

858 

    def _execute_command(self, command, sql):
        """Send *command* with argument *sql*, starting a new packet sequence.

        Any unfinished unbuffered result and any pending further results are
        drained first.  An argument too long for one packet is split into
        MAX_PACKET_LEN-sized packets.

        :param command: A ``COMMAND.*`` code.
        :param sql: The command argument as ``str`` (encoded with the
            connection encoding) or bytes.
        :raise InterfaceError: If the connection is closed.
        """
        if not self._sock:
            raise err.InterfaceError(0, "")

        # If the last query was unbuffered, make sure it finishes before
        # sending new commands
        if self._result is not None:
            if self._result.unbuffered_active:
                warnings.warn("Previous unbuffered result was left incomplete")
                self._result._finish_unbuffered_query()
            while self._result.has_next:
                self.next_result()
            self._result = None

        if isinstance(sql, str):
            sql = sql.encode(self.encoding)

        packet_size = min(MAX_PACKET_LEN, len(sql) + 1)  # +1 is for command

        # tiny optimization: build first packet manually instead of
        # calling self.write_packet()
        # "<i" packs packet_size (< 2**24) as 4 little-endian bytes whose top
        # byte is 0, so it doubles as the 3-byte length + sequence id 0.
        prelude = struct.pack("<iB", packet_size, command)
        packet = prelude + sql[: packet_size - 1]
        self._write_bytes(packet)
        if DEBUG:
            dump_packet(packet)
        self._next_seq_id = 1

        if packet_size < MAX_PACKET_LEN:
            return

        # Continuation packets.  When the data ends exactly on a packet
        # boundary, the loop sends a final empty packet to mark the end.
        sql = sql[packet_size - 1 :]
        while True:
            packet_size = min(MAX_PACKET_LEN, len(sql))
            self.write_packet(sql[:packet_size])
            sql = sql[packet_size:]
            if not sql and packet_size < MAX_PACKET_LEN:
                break

901 

902 def _request_authentication(self): 

903 # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse 

904 if int(self.server_version.split(".", 1)[0]) >= 5: 

905 self.client_flag |= CLIENT.MULTI_RESULTS 

906 

907 if self.user is None: 

908 raise ValueError("Did not specify a username") 

909 

910 charset_id = charset_by_name(self.charset).id 

911 if isinstance(self.user, str): 

912 self.user = self.user.encode(self.encoding) 

913 

914 # Determine flags for the initial handshake packet. 

915 # CLIENT.SSL is added conditionally: for REQUIRED mode it is already set in 

916 # self.client_flag, but for PREFERRED mode it is only added when the server 

917 # also advertises SSL support. 

918 # _do_ssl is set here and checked below for sha256_password auth. 

919 client_flags = self.client_flag 

920 if self.ssl: 

921 if self.server_capabilities & CLIENT.SSL: 

922 # SSL upgrade: include CLIENT.SSL flag and wrap the socket. 

923 _do_ssl = True 

924 client_flags |= CLIENT.SSL 

925 elif self._ssl_required: 

926 raise err.OperationalError( 

927 CR.CR_SSL_CONNECTION_ERROR, 

928 "SSL is required but the server doesn't support it", 

929 ) 

930 else: 

931 # PREFERRED mode: server doesn't support SSL, fall back to non-SSL. 

932 _do_ssl = False 

933 else: 

934 _do_ssl = False 

935 

936 data_init = struct.pack( 

937 "<iIB23s", client_flags, MAX_PACKET_LEN, charset_id, b"" 

938 ) 

939 

940 if _do_ssl: 

941 self.write_packet(data_init) 

942 self._sock = self.ctx.wrap_socket(self._sock, server_hostname=self.host) 

943 self._rfile = self._sock.makefile("rb") 

944 self._secure = True 

945 

946 data = data_init + self.user + b"\0" 

947 

948 authresp = b"" 

949 plugin_name = None 

950 

951 if self._auth_plugin_name == "": 

952 plugin_name = b"" 

953 authresp = _auth.scramble_native_password(self.password, self.salt) 

954 elif self._auth_plugin_name == "mysql_native_password": 

955 plugin_name = b"mysql_native_password" 

956 authresp = _auth.scramble_native_password(self.password, self.salt) 

957 elif self._auth_plugin_name == "caching_sha2_password": 

958 plugin_name = b"caching_sha2_password" 

959 if self.password: 

960 if DEBUG: 

961 print("caching_sha2: trying fast path") 

962 authresp = _auth.scramble_caching_sha2(self.password, self.salt) 

963 else: 

964 if DEBUG: 

965 print("caching_sha2: empty password") 

966 elif self._auth_plugin_name == "sha256_password": 

967 plugin_name = b"sha256_password" 

968 if _do_ssl: 

969 authresp = self.password + b"\0" 

970 elif self.password: 

971 authresp = b"\1" # request public key 

972 else: 

973 authresp = b"\0" # empty password 

974 

975 if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA: 

976 data += _lenenc_int(len(authresp)) + authresp 

977 elif self.server_capabilities & CLIENT.SECURE_CONNECTION: 

978 data += struct.pack("B", len(authresp)) + authresp 

979 else: # pragma: no cover - not testing against servers without secure auth (>=5.0) 

980 data += authresp + b"\0" 

981 

982 if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB: 

983 if isinstance(self.db, str): 

984 self.db = self.db.encode(self.encoding) 

985 data += self.db + b"\0" 

986 

987 if self.server_capabilities & CLIENT.PLUGIN_AUTH: 

988 data += (plugin_name or b"") + b"\0" 

989 

990 if self.server_capabilities & CLIENT.CONNECT_ATTRS: 

991 connect_attrs = b"" 

992 for k, v in self._connect_attrs.items(): 

993 k = k.encode("utf-8") 

994 connect_attrs += _lenenc_int(len(k)) + k 

995 v = v.encode("utf-8") 

996 connect_attrs += _lenenc_int(len(v)) + v 

997 data += _lenenc_int(len(connect_attrs)) + connect_attrs 

998 

999 self.write_packet(data) 

1000 auth_packet = self._read_packet() 

1001 

1002 # if authentication method isn't accepted the first byte 

1003 # will have the octet 254 

1004 if auth_packet.is_auth_switch_request(): 

1005 if DEBUG: 

1006 print("received auth switch") 

1007 # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest 

1008 auth_packet.read_uint8() # 0xfe packet identifier 

1009 plugin_name = auth_packet.read_string() 

1010 if ( 

1011 self.server_capabilities & CLIENT.PLUGIN_AUTH 

1012 and plugin_name is not None 

1013 ): 

1014 auth_packet = self._process_auth(plugin_name, auth_packet) 

1015 else: 

1016 raise err.OperationalError("received unknown auth switch request") 

1017 elif auth_packet.is_extra_auth_data(): 

1018 if DEBUG: 

1019 print("received extra data") 

1020 # https://dev.mysql.com/doc/internals/en/successful-authentication.html 

1021 if self._auth_plugin_name == "caching_sha2_password": 

1022 auth_packet = _auth.caching_sha2_password_auth(self, auth_packet) 

1023 elif self._auth_plugin_name == "sha256_password": 

1024 auth_packet = _auth.sha256_password_auth(self, auth_packet) 

1025 else: 

1026 raise err.OperationalError( 

1027 "Received extra packet for auth method %r", self._auth_plugin_name 

1028 ) 

1029 

1030 if DEBUG: 

1031 print("Succeed to auth") 

1032 

1033 def _process_auth(self, plugin_name, auth_packet): 

1034 handler = self._get_auth_plugin_handler(plugin_name) 

1035 if handler: 

1036 try: 

1037 return handler.authenticate(auth_packet) 

1038 except AttributeError: 

1039 if plugin_name != b"dialog": 

1040 raise err.OperationalError( 

1041 CR.CR_AUTH_PLUGIN_CANNOT_LOAD, 

1042 f"Authentication plugin '{plugin_name}'" 

1043 f" not loaded: - {type(handler)!r} missing authenticate method", 

1044 ) 

1045 if plugin_name == b"caching_sha2_password": 

1046 return _auth.caching_sha2_password_auth(self, auth_packet) 

1047 elif plugin_name == b"sha256_password": 

1048 return _auth.sha256_password_auth(self, auth_packet) 

1049 elif plugin_name == b"mysql_native_password": 

1050 data = _auth.scramble_native_password(self.password, auth_packet.read_all()) 

1051 elif plugin_name == b"client_ed25519": 

1052 data = _auth.ed25519_password(self.password, auth_packet.read_all()) 

1053 elif plugin_name == b"mysql_old_password": 

1054 data = ( 

1055 _auth.scramble_old_password(self.password, auth_packet.read_all()) 

1056 + b"\0" 

1057 ) 

1058 elif plugin_name == b"mysql_clear_password": 

1059 # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html 

1060 data = self.password + b"\0" 

1061 elif plugin_name == b"dialog": 

1062 pkt = auth_packet 

1063 while True: 

1064 flag = pkt.read_uint8() 

1065 echo = (flag & 0x06) == 0x02 

1066 last = (flag & 0x01) == 0x01 

1067 prompt = pkt.read_all() 

1068 

1069 if prompt == b"Password: ": 

1070 self.write_packet(self.password + b"\0") 

1071 elif handler: 

1072 resp = "no response - TypeError within plugin.prompt method" 

1073 try: 

1074 resp = handler.prompt(echo, prompt) 

1075 self.write_packet(resp + b"\0") 

1076 except AttributeError: 

1077 raise err.OperationalError( 

1078 CR.CR_AUTH_PLUGIN_CANNOT_LOAD, 

1079 f"Authentication plugin '{plugin_name}'" 

1080 f" not loaded: - {handler!r} missing prompt method", 

1081 ) 

1082 except TypeError: 

1083 raise err.OperationalError( 

1084 CR.CR_AUTH_PLUGIN_ERR, 

1085 f"Authentication plugin '{plugin_name}'" 

1086 f" {handler!r} didn't respond with string. Returned '{resp!r}' to prompt {prompt!r}", 

1087 ) 

1088 else: 

1089 raise err.OperationalError( 

1090 CR.CR_AUTH_PLUGIN_CANNOT_LOAD, 

1091 f"Authentication plugin '{plugin_name}' not configured", 

1092 ) 

1093 pkt = self._read_packet() 

1094 pkt.check_error() 

1095 if pkt.is_ok_packet() or last: 

1096 break 

1097 return pkt 

1098 else: 

1099 raise err.OperationalError( 

1100 CR.CR_AUTH_PLUGIN_CANNOT_LOAD, 

1101 "Authentication plugin '%s' not configured" % plugin_name, 

1102 ) 

1103 

1104 self.write_packet(data) 

1105 pkt = self._read_packet() 

1106 pkt.check_error() 

1107 return pkt 

1108 

1109 def _get_auth_plugin_handler(self, plugin_name): 

1110 plugin_class = self._auth_plugin_map.get(plugin_name) 

1111 if not plugin_class and isinstance(plugin_name, bytes): 

1112 plugin_class = self._auth_plugin_map.get(plugin_name.decode("ascii")) 

1113 if plugin_class: 

1114 try: 

1115 handler = plugin_class(self) 

1116 except TypeError: 

1117 raise err.OperationalError( 

1118 CR.CR_AUTH_PLUGIN_CANNOT_LOAD, 

1119 f"Authentication plugin '{plugin_name}'" 

1120 f" not loaded: - {plugin_class!r} cannot be constructed with connection object", 

1121 ) 

1122 else: 

1123 handler = None 

1124 return handler 

1125 

1126 # _mysql support 

1127 def thread_id(self): 

1128 return self.server_thread_id[0] 

1129 

1130 def character_set_name(self): 

1131 return self.charset 

1132 

1133 def get_host_info(self): 

1134 return self.host_info 

1135 

1136 def get_proto_info(self): 

1137 return self.protocol_version 

1138 

1139 def _get_server_information(self): 

1140 i = 0 

1141 packet = self._read_packet() 

1142 data = packet.get_all_data() 

1143 

1144 self.protocol_version = data[i] 

1145 i += 1 

1146 

1147 server_end = data.find(b"\0", i) 

1148 self.server_version = data[i:server_end].decode("latin1") 

1149 i = server_end + 1 

1150 

1151 self.server_thread_id = struct.unpack("<I", data[i : i + 4]) 

1152 i += 4 

1153 

1154 self.salt = data[i : i + 8] 

1155 i += 9 # 8 + 1(filler) 

1156 

1157 self.server_capabilities = struct.unpack("<H", data[i : i + 2])[0] 

1158 i += 2 

1159 

1160 if len(data) >= i + 6: 

1161 lang, stat, cap_h, salt_len = struct.unpack("<BHHB", data[i : i + 6]) 

1162 i += 6 

1163 # TODO: deprecate server_language and server_charset. 

1164 # mysqlclient-python doesn't provide it. 

1165 self.server_language = lang 

1166 try: 

1167 self.server_charset = charset_by_id(lang).name 

1168 except KeyError: 

1169 # unknown collation 

1170 self.server_charset = None 

1171 

1172 self.server_status = stat 

1173 if DEBUG: 

1174 print("server_status: %x" % stat) 

1175 

1176 self.server_capabilities |= cap_h << 16 

1177 if DEBUG: 

1178 print("salt_len:", salt_len) 

1179 salt_len = max(12, salt_len - 9) 

1180 

1181 # reserved 

1182 i += 10 

1183 

1184 if len(data) >= i + salt_len: 

1185 # salt_len includes auth_plugin_data_part_1 and filler 

1186 self.salt += data[i : i + salt_len] 

1187 i += salt_len 

1188 

1189 i += 1 

1190 # AUTH PLUGIN NAME may appear here. 

1191 if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i: 

1192 # Due to Bug#59453 the auth-plugin-name is missing the terminating 

1193 # NUL-char in versions prior to 5.5.10 and 5.6.2. 

1194 # ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake 

1195 # didn't use version checks as mariadb is corrected and reports 

1196 # earlier than those two. 

1197 server_end = data.find(b"\0", i) 

1198 if server_end < 0: # pragma: no cover - very specific upstream bug 

1199 # not found \0 and last field so take it all 

1200 self._auth_plugin_name = data[i:].decode("utf-8") 

1201 else: 

1202 self._auth_plugin_name = data[i:server_end].decode("utf-8") 

1203 

1204 if _DEFAULT_AUTH_PLUGIN is not None: # for tests 

1205 self._auth_plugin_name = _DEFAULT_AUTH_PLUGIN 

1206 

1207 def get_server_info(self): 

1208 return self.server_version 

1209 

1210 Warning = err.Warning 

1211 Error = err.Error 

1212 InterfaceError = err.InterfaceError 

1213 DatabaseError = err.DatabaseError 

1214 DataError = err.DataError 

1215 OperationalError = err.OperationalError 

1216 IntegrityError = err.IntegrityError 

1217 InternalError = err.InternalError 

1218 ProgrammingError = err.ProgrammingError 

1219 NotSupportedError = err.NotSupportedError 

1220 

1221 

1222class MySQLResult: 

1223 def __init__(self, connection): 

1224 """ 

1225 :type connection: Connection 

1226 """ 

1227 self.connection = connection 

1228 self.affected_rows = None 

1229 self.insert_id = None 

1230 self.server_status = None 

1231 self.warning_count = 0 

1232 self.message = None 

1233 self.field_count = 0 

1234 self.description = None 

1235 self.rows = None 

1236 self.has_next = None 

1237 self.unbuffered_active = False 

1238 

1239 def __del__(self): 

1240 if self.unbuffered_active: 

1241 self._finish_unbuffered_query() 

1242 

1243 def read(self): 

1244 try: 

1245 first_packet = self.connection._read_packet() 

1246 

1247 if first_packet.is_ok_packet(): 

1248 self._read_ok_packet(first_packet) 

1249 elif first_packet.is_load_local_packet(): 

1250 self._read_load_local_packet(first_packet) 

1251 else: 

1252 self._read_result_packet(first_packet) 

1253 finally: 

1254 self.connection = None 

1255 

1256 def init_unbuffered_query(self): 

1257 """ 

1258 :raise OperationalError: If the connection to the MySQL server is lost. 

1259 :raise InternalError: 

1260 """ 

1261 first_packet = self.connection._read_packet() 

1262 

1263 if first_packet.is_ok_packet(): 

1264 self.connection = None 

1265 self._read_ok_packet(first_packet) 

1266 elif first_packet.is_load_local_packet(): 

1267 try: 

1268 self._read_load_local_packet(first_packet) 

1269 finally: 

1270 self.connection = None 

1271 else: 

1272 self.field_count = first_packet.read_length_encoded_integer() 

1273 self._get_descriptions() 

1274 

1275 # Apparently, MySQLdb picks this number because it's the maximum 

1276 # value of a 64bit unsigned integer. Since we're emulating MySQLdb, 

1277 # we set it to this instead of None, which would be preferred. 

1278 self.affected_rows = 18446744073709551615 

1279 self.unbuffered_active = True 

1280 

1281 def _read_ok_packet(self, first_packet): 

1282 ok_packet = OKPacketWrapper(first_packet) 

1283 self.affected_rows = ok_packet.affected_rows 

1284 self.insert_id = ok_packet.insert_id 

1285 self.server_status = ok_packet.server_status 

1286 self.warning_count = ok_packet.warning_count 

1287 self.message = ok_packet.message 

1288 self.has_next = ok_packet.has_next 

1289 

1290 def _read_load_local_packet(self, first_packet): 

1291 if not self.connection._local_infile: 

1292 raise RuntimeError( 

1293 "**WARN**: Received LOAD_LOCAL packet but local_infile option is false." 

1294 ) 

1295 load_packet = LoadLocalPacketWrapper(first_packet) 

1296 sender = LoadLocalFile(load_packet.filename, self.connection) 

1297 try: 

1298 sender.send_data() 

1299 except: 

1300 self.connection._read_packet() # skip ok packet 

1301 raise 

1302 

1303 ok_packet = self.connection._read_packet() 

1304 if ( 

1305 not ok_packet.is_ok_packet() 

1306 ): # pragma: no cover - upstream induced protocol error 

1307 raise err.OperationalError( 

1308 CR.CR_COMMANDS_OUT_OF_SYNC, 

1309 "Commands Out of Sync", 

1310 ) 

1311 self._read_ok_packet(ok_packet) 

1312 

1313 def _check_packet_is_eof(self, packet): 

1314 if not packet.is_eof_packet(): 

1315 return False 

1316 # TODO: Support CLIENT.DEPRECATE_EOF 

1317 # 1) Add DEPRECATE_EOF to CAPABILITIES 

1318 # 2) Mask CAPABILITIES with server_capabilities 

1319 # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: 

1320 # use OKPacketWrapper instead of EOFPacketWrapper 

1321 wp = EOFPacketWrapper(packet) 

1322 self.warning_count = wp.warning_count 

1323 self.has_next = wp.has_next 

1324 return True 

1325 

1326 def _read_result_packet(self, first_packet): 

1327 self.field_count = first_packet.read_length_encoded_integer() 

1328 self._get_descriptions() 

1329 self._read_rowdata_packet() 

1330 

1331 def _read_rowdata_packet_unbuffered(self): 

1332 # Check if in an active query 

1333 if not self.unbuffered_active: 

1334 return 

1335 

1336 # EOF 

1337 packet = self.connection._read_packet() 

1338 if self._check_packet_is_eof(packet): 

1339 self.unbuffered_active = False 

1340 self.connection = None 

1341 self.rows = None 

1342 return 

1343 

1344 row = self._read_row_from_packet(packet) 

1345 self.affected_rows = 1 

1346 self.rows = (row,) # rows should tuple of row for MySQL-python compatibility. 

1347 return row 

1348 

1349 def _finish_unbuffered_query(self): 

1350 # After much reading on the MySQL protocol, it appears that there is, 

1351 # in fact, no way to stop MySQL from sending all the data after 

1352 # executing a query, so we just spin, and wait for an EOF packet. 

1353 while self.unbuffered_active: 

1354 try: 

1355 packet = self.connection._read_packet() 

1356 except err.OperationalError as e: 

1357 if e.args[0] in ( 

1358 ER.QUERY_TIMEOUT, 

1359 ER.STATEMENT_TIMEOUT, 

1360 ): 

1361 # if the query timed out we can simply ignore this error 

1362 self.unbuffered_active = False 

1363 self.connection = None 

1364 return 

1365 

1366 raise 

1367 

1368 if self._check_packet_is_eof(packet): 

1369 self.unbuffered_active = False 

1370 self.connection = None # release reference to kill cyclic reference. 

1371 

1372 def _read_rowdata_packet(self): 

1373 """Read a rowdata packet for each data row in the result set.""" 

1374 rows = [] 

1375 while True: 

1376 packet = self.connection._read_packet() 

1377 if self._check_packet_is_eof(packet): 

1378 self.connection = None # release reference to kill cyclic reference. 

1379 break 

1380 rows.append(self._read_row_from_packet(packet)) 

1381 

1382 self.affected_rows = len(rows) 

1383 self.rows = tuple(rows) 

1384 

1385 def _read_row_from_packet(self, packet): 

1386 row = [] 

1387 for encoding, converter in self.converters: 

1388 try: 

1389 data = packet.read_length_coded_string() 

1390 except IndexError: 

1391 # No more columns in this row 

1392 # See https://github.com/PyMySQL/PyMySQL/pull/434 

1393 break 

1394 if data is not None: 

1395 if encoding is not None: 

1396 data = data.decode(encoding) 

1397 if DEBUG: 

1398 print("DEBUG: DATA = ", data) 

1399 if converter is not None: 

1400 data = converter(data) 

1401 row.append(data) 

1402 return tuple(row) 

1403 

1404 def _get_descriptions(self): 

1405 """Read a column descriptor packet for each column in the result.""" 

1406 self.fields = [] 

1407 self.converters = [] 

1408 use_unicode = self.connection.use_unicode 

1409 conn_encoding = self.connection.encoding 

1410 description = [] 

1411 

1412 for i in range(self.field_count): 

1413 field = self.connection._read_packet(FieldDescriptorPacket) 

1414 self.fields.append(field) 

1415 description.append(field.description()) 

1416 field_type = field.type_code 

1417 if use_unicode: 

1418 if field_type == FIELD_TYPE.JSON: 

1419 # When SELECT from JSON column: charset = binary 

1420 # When SELECT CAST(... AS JSON): charset = connection encoding 

1421 # This behavior is different from TEXT / BLOB. 

1422 # We should decode result by connection encoding regardless charsetnr. 

1423 # See https://github.com/PyMySQL/PyMySQL/issues/488 

1424 encoding = conn_encoding # SELECT CAST(... AS JSON) 

1425 elif field_type in TEXT_TYPES: 

1426 if field.charsetnr == 63: # binary 

1427 # TEXTs with charset=binary means BINARY types. 

1428 encoding = None 

1429 else: 

1430 encoding = conn_encoding 

1431 else: 

1432 # Integers, Dates and Times, and other basic data is encoded in ascii 

1433 encoding = "ascii" 

1434 else: 

1435 encoding = None 

1436 converter = self.connection.decoders.get(field_type) 

1437 if converter is converters.through: 

1438 converter = None 

1439 if DEBUG: 

1440 print(f"DEBUG: field={field}, converter={converter}") 

1441 self.converters.append((encoding, converter)) 

1442 

1443 eof_packet = self.connection._read_packet() 

1444 assert eof_packet.is_eof_packet(), "Protocol error, expecting EOF" 

1445 self.description = tuple(description) 

1446 

1447 

1448class LoadLocalFile: 

1449 def __init__(self, filename, connection): 

1450 self.filename = filename 

1451 self.connection = connection 

1452 

1453 def send_data(self): 

1454 """Send data packets from the local file to the server""" 

1455 if not self.connection._sock: 

1456 raise err.InterfaceError(0, "") 

1457 conn: Connection = self.connection 

1458 

1459 try: 

1460 with open(self.filename, "rb") as open_file: 

1461 packet_size = min( 

1462 conn.max_allowed_packet, 16 * 1024 

1463 ) # 16KB is efficient enough 

1464 while True: 

1465 chunk = open_file.read(packet_size) 

1466 if not chunk: 

1467 break 

1468 conn.write_packet(chunk) 

1469 except OSError: 

1470 raise err.OperationalError( 

1471 ER.FILE_NOT_FOUND, 

1472 f"Can't find file '{self.filename}'", 

1473 ) 

1474 finally: 

1475 if not conn._closed: 

1476 # send the empty packet to signify we are done sending data 

1477 conn.write_packet(b"")