Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/pymysql/connections.py: 27%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

836 statements  

1# Python implementation of the MySQL client-server protocol 

2# http://dev.mysql.com/doc/internals/en/client-server-protocol.html 

3# Error codes: 

4# https://dev.mysql.com/doc/refman/5.5/en/error-handling.html 

5import errno 

6import os 

7import socket 

8import struct 

9import sys 

10import traceback 

11import warnings 

12 

13from . import _auth 

14 

15from .charset import charset_by_name, charset_by_id 

16from .constants import CLIENT, COMMAND, CR, ER, FIELD_TYPE, SERVER_STATUS 

17from . import converters 

18from .cursors import Cursor 

19from .optionfile import Parser 

20from .protocol import ( 

21 dump_packet, 

22 MysqlPacket, 

23 FieldDescriptorPacket, 

24 OKPacketWrapper, 

25 EOFPacketWrapper, 

26 LoadLocalPacketWrapper, 

27) 

28from . import err, VERSION_STRING 

29 

# TLS support is optional: some Python builds ship without the ``ssl`` module.
try:
    import ssl
except ImportError:
    ssl = None
    SSL_ENABLED = False
else:
    SSL_ENABLED = True

# Default login name: the OS user running this process, if it can be found.
try:
    import getpass
except ImportError:
    DEFAULT_USER = None
else:
    try:
        DEFAULT_USER = getpass.getuser()
    except (ImportError, KeyError, OSError):
        # When there's no entry in the OS database for the current user:
        # KeyError is raised in Python 3.12 and below,
        # OSError is raised in Python 3.13+.
        DEFAULT_USER = None
    else:
        del getpass

DEBUG = False
# When this is not None, it is used instead of the server's default auth plugin.
_DEFAULT_AUTH_PLUGIN = None

# Field types grouped together as string/blob-like columns (consumed elsewhere
# in this module).
TEXT_TYPES = {
    FIELD_TYPE.BIT,
    FIELD_TYPE.TINY_BLOB,
    FIELD_TYPE.BLOB,
    FIELD_TYPE.MEDIUM_BLOB,
    FIELD_TYPE.LONG_BLOB,
    FIELD_TYPE.STRING,
    FIELD_TYPE.VAR_STRING,
    FIELD_TYPE.VARCHAR,
    FIELD_TYPE.GEOMETRY,
}


DEFAULT_CHARSET = "utf8mb4"

# Largest payload one physical packet can carry (its length field is 3 bytes).
MAX_PACKET_LEN = (1 << 24) - 1

68 

69 

70def _pack_int24(n): 

71 return struct.pack("<I", n)[:3] 

72 

73 

74# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger 

75def _lenenc_int(i): 

76 if i < 0: 

77 raise ValueError( 

78 "Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i 

79 ) 

80 elif i < 0xFB: 

81 return bytes([i]) 

82 elif i < (1 << 16): 

83 return b"\xfc" + struct.pack("<H", i) 

84 elif i < (1 << 24): 

85 return b"\xfd" + struct.pack("<I", i)[:3] 

86 elif i < (1 << 64): 

87 return b"\xfe" + struct.pack("<Q", i) 

88 else: 

89 raise ValueError( 

90 f"Encoding {i:x} is larger than {1 << 64:x} - no representation in LengthEncodedInteger" 

91 ) 

92 

93 

94class Connection: 

95 """ 

96 Representation of a socket with a mysql server. 

97 

98 The proper way to get an instance of this class is to call 

99 connect(). 

100 

101 Establish a connection to the MySQL database. Accepts several 

102 arguments: 

103 

104 :param host: Host where the database server is located. 

105 :param user: Username to log in as. 

106 :param password: Password to use. 

107 :param database: Database to use, None to not use a particular one. 

108 :param port: MySQL port to use, default is usually OK. (default: 3306) 

109 :param bind_address: When the client has multiple network interfaces, specify 

110 the interface from which to connect to the host. Argument can be 

111 a hostname or an IP address. 

112 :param unix_socket: Use a unix socket rather than TCP/IP. 

113 :param read_timeout: The timeout for reading from the connection in seconds. 

114 (default: None - no timeout) 

115 :param write_timeout: The timeout for writing to the connection in seconds. 

116 (default: None - no timeout) 

117 :param str charset: Charset to use. 

118 :param str collation: Collation name to use. 

119 :param sql_mode: Default SQL_MODE to use. 

120 :param read_default_file: 

121 Specifies my.cnf file to read these parameters from under the [client] section. 

122 :param conv: 

123 Conversion dictionary to use instead of the default one. 

124 This is used to provide custom marshalling and unmarshalling of types. 

125 See converters. 

126 :param use_unicode: 

127 Whether or not to default to unicode strings. 

128 This option defaults to true. 

129 :param client_flag: Custom flags to send to MySQL. Find potential values in constants.CLIENT. 

130 :param cursorclass: Custom cursor class to use. 

131 :param init_command: Initial SQL statement to run when connection is established. 

132 :param connect_timeout: The timeout for connecting to the database in seconds. 

133 (default: 10, min: 1, max: 31536000) 

134 :param ssl: A dict of arguments similar to mysql_ssl_set()'s parameters or an ssl.SSLContext. 

135 :param ssl_ca: Path to the file that contains a PEM-formatted CA certificate. 

136 :param ssl_cert: Path to the file that contains a PEM-formatted client certificate. 

137 :param ssl_disabled: A boolean value that disables usage of TLS. 

138 :param ssl_key: Path to the file that contains a PEM-formatted private key for 

139 the client certificate. 

140 :param ssl_key_password: The password for the client certificate private key. 

141 :param ssl_verify_cert: Set to true to check the server certificate's validity. 

142 :param ssl_verify_identity: Set to true to check the server's identity. 

143 :param read_default_group: Group to read from in the configuration file. 

144 :param autocommit: Autocommit mode. None means use server default. (default: False) 

145 :param local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False) 

146 :param max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB) 

147 Only used to limit size of "LOAD LOCAL INFILE" data packet smaller than default (16KB). 

148 :param defer_connect: Don't explicitly connect on construction - wait for connect call. 

149 (default: False) 

150 :param auth_plugin_map: A dict of plugin names to a class that processes that plugin. 

151 The class will take the Connection object as the argument to the constructor. 

152 The class needs an authenticate method taking an authentication packet as 

153 an argument. For the dialog plugin, a prompt(echo, prompt) method can be used 

154 (if no authenticate method) for returning a string from the user. (experimental) 

155 :param server_public_key: SHA256 authentication plugin public key value. (default: None) 

156 :param binary_prefix: Add _binary prefix on bytes and bytearray. (default: False) 

157 :param compress: Not supported. 

158 :param named_pipe: Not supported. 

159 :param db: **DEPRECATED** Alias for database. 

160 :param passwd: **DEPRECATED** Alias for password. 

161 

162 See `Connection <https://www.python.org/dev/peps/pep-0249/#connection-objects>`_ in the 

163 specification. 

164 """ 

165 

166 _sock = None 

167 _rfile = None 

168 _auth_plugin_name = "" 

169 _closed = False 

170 _secure = False 

171 

def __init__(
    self,
    *,
    user=None,  # The first four arguments are based on the DB-API 2.0 recommendation.
    password="",
    host=None,
    database=None,
    unix_socket=None,
    port=0,
    charset="",
    collation=None,
    sql_mode=None,
    read_default_file=None,
    conv=None,
    use_unicode=True,
    client_flag=0,
    cursorclass=Cursor,
    init_command=None,
    connect_timeout=10,
    read_default_group=None,
    autocommit=False,
    local_infile=False,
    max_allowed_packet=16 * 1024 * 1024,
    defer_connect=False,
    auth_plugin_map=None,
    read_timeout=None,
    write_timeout=None,
    bind_address=None,
    binary_prefix=False,
    program_name=None,
    server_public_key=None,
    ssl=None,
    ssl_ca=None,
    ssl_cert=None,
    ssl_disabled=None,
    ssl_key=None,
    ssl_key_password=None,
    ssl_verify_cert=None,
    ssl_verify_identity=None,
    compress=None,  # not supported
    named_pipe=None,  # not supported
    passwd=None,  # deprecated
    db=None,  # deprecated
):
    """Validate and store the connection settings, then connect unless deferred.

    Truthy explicit arguments take precedence over values read from
    *read_default_file*. See the class docstring for each parameter.

    :raise NotImplementedError: If *compress* or *named_pipe* is requested,
        or TLS is requested but the ``ssl`` module is unavailable.
    :raise ValueError: If *port* is not an int, or *connect_timeout*,
        *read_timeout* or *write_timeout* is out of range.
    """
    # Deprecated aliases. A DeprecationWarning is planned, see
    # https://github.com/PyMySQL/PyMySQL/issues/939
    if database is None and db is not None:
        database = db
    if not password and passwd is not None:
        password = passwd

    if compress or named_pipe:
        raise NotImplementedError(
            "compress and named_pipe arguments are not supported"
        )

    self._local_infile = bool(local_infile)
    if self._local_infile:
        client_flag |= CLIENT.LOCAL_FILES

    # Naming a group without a file implies the platform's default option file.
    if read_default_group and not read_default_file:
        read_default_file = (
            "c:\\my.ini" if sys.platform.startswith("win") else "/etc/my.cnf"
        )

    if read_default_file:
        read_default_group = read_default_group or "client"

        cfg = Parser()
        cfg.read(os.path.expanduser(read_default_file))

        def _opt(key, arg):
            # A truthy explicit argument wins. Otherwise use the option-file
            # value, falling back to the argument when the lookup fails.
            if arg:
                return arg
            try:
                return cfg.get(read_default_group, key)
            except Exception:
                return arg

        user = _opt("user", user)
        password = _opt("password", password)
        host = _opt("host", host)
        database = _opt("database", database)
        unix_socket = _opt("socket", unix_socket)
        port = int(_opt("port", port))
        bind_address = _opt("bind-address", bind_address)
        charset = _opt("default-character-set", charset)
        if not ssl:
            ssl = {}
        if isinstance(ssl, dict):
            for key in ("ca", "capath", "cert", "key", "password", "cipher"):
                value = _opt("ssl-" + key, ssl.get(key))
                if value:
                    ssl[key] = value

    self.ssl = False
    if not ssl_disabled:
        # The individual ssl_* arguments replace any ``ssl`` dict/context.
        if any((ssl_ca, ssl_cert, ssl_key, ssl_verify_cert, ssl_verify_identity)):
            ssl = {
                "ca": ssl_ca,
                "check_hostname": bool(ssl_verify_identity),
                "verify_mode": False if ssl_verify_cert is None else ssl_verify_cert,
            }
            for name, value in (
                ("cert", ssl_cert),
                ("key", ssl_key),
                ("password", ssl_key_password),
            ):
                if value is not None:
                    ssl[name] = value
        if ssl:
            if not SSL_ENABLED:
                raise NotImplementedError("ssl module not found")
            self.ssl = True
            client_flag |= CLIENT.SSL
            self.ctx = self._create_ssl_ctx(ssl)

    self.host = host or "localhost"
    self.port = port or 3306
    if type(self.port) is not int:
        raise ValueError("port should be of type int")
    self.user = user or DEFAULT_USER
    password = password or b""
    if isinstance(password, str):
        password = password.encode("latin1")
    self.password = password
    self.db = database
    self.unix_socket = unix_socket
    self.bind_address = bind_address

    if not (0 < connect_timeout <= 31536000):
        raise ValueError("connect_timeout should be >0 and <=31536000")
    self.connect_timeout = connect_timeout or None
    if read_timeout is not None and read_timeout <= 0:
        raise ValueError("read_timeout should be > 0")
    if write_timeout is not None and write_timeout <= 0:
        raise ValueError("write_timeout should be > 0")
    self._read_timeout = read_timeout
    self._write_timeout = write_timeout

    self.charset = charset or DEFAULT_CHARSET
    self.collation = collation
    self.use_unicode = use_unicode
    # Python codec used for text sent on this connection.
    self.encoding = charset_by_name(self.charset).encoding

    client_flag |= CLIENT.CAPABILITIES
    if self.db:
        client_flag |= CLIENT.CONNECT_WITH_DB
    self.client_flag = client_flag

    self.cursorclass = cursorclass
    self._result = None
    self._affected_rows = 0
    self.host_info = "Not connected"

    # Requested autocommit mode. None means keep the server default.
    self.autocommit_mode = autocommit

    if conv is None:
        conv = converters.conversions
    # MySQLdb compatibility: int keys are decoders, every other key is an
    # encoder key.
    self.encoders = {}
    self.decoders = {}
    for key, func in conv.items():
        if type(key) is int:
            self.decoders[key] = func
        else:
            self.encoders[key] = func
    self.sql_mode = sql_mode
    self.init_command = init_command
    self.max_allowed_packet = max_allowed_packet
    self._auth_plugin_map = auth_plugin_map or {}
    self._binary_prefix = binary_prefix
    self.server_public_key = server_public_key

    self._connect_attrs = {
        "_client_name": "pymysql",
        "_client_version": VERSION_STRING,
        "_pid": str(os.getpid()),
    }
    if program_name:
        self._connect_attrs["program_name"] = program_name

    if defer_connect:
        self._sock = None
    else:
        self.connect()

366 

367 def __enter__(self): 

368 return self 

369 

370 def __exit__(self, *exc_info): 

371 del exc_info 

372 self.close() 

373 

374 def _create_ssl_ctx(self, sslp): 

375 if isinstance(sslp, ssl.SSLContext): 

376 return sslp 

377 ca = sslp.get("ca") 

378 capath = sslp.get("capath") 

379 hasnoca = ca is None and capath is None 

380 ctx = ssl.create_default_context(cafile=ca, capath=capath) 

381 

382 # Python 3.13 enables VERIFY_X509_STRICT by default. 

383 # But self signed certificates that are generated by MySQL automatically 

384 # doesn't pass the verification. 

385 ctx.verify_flags &= ~ssl.VERIFY_X509_STRICT 

386 

387 ctx.check_hostname = not hasnoca and sslp.get("check_hostname", True) 

388 verify_mode_value = sslp.get("verify_mode") 

389 if verify_mode_value is None: 

390 ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED 

391 elif isinstance(verify_mode_value, bool): 

392 ctx.verify_mode = ssl.CERT_REQUIRED if verify_mode_value else ssl.CERT_NONE 

393 else: 

394 if isinstance(verify_mode_value, str): 

395 verify_mode_value = verify_mode_value.lower() 

396 if verify_mode_value in ("none", "0", "false", "no"): 

397 ctx.verify_mode = ssl.CERT_NONE 

398 elif verify_mode_value == "optional": 

399 ctx.verify_mode = ssl.CERT_OPTIONAL 

400 elif verify_mode_value in ("required", "1", "true", "yes"): 

401 ctx.verify_mode = ssl.CERT_REQUIRED 

402 else: 

403 ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED 

404 if "cert" in sslp: 

405 ctx.load_cert_chain( 

406 sslp["cert"], keyfile=sslp.get("key"), password=sslp.get("password") 

407 ) 

408 if "cipher" in sslp: 

409 ctx.set_ciphers(sslp["cipher"]) 

410 ctx.options |= ssl.OP_NO_SSLv2 

411 ctx.options |= ssl.OP_NO_SSLv3 

412 return ctx 

413 

def close(self):
    """Send COM_QUIT to the server, then close the socket.

    See `Connection.close() <https://www.python.org/dev/peps/pep-0249/#Connection.close>`_
    in the specification.

    :raise Error: If the connection is already closed.
    """
    if self._closed:
        raise err.Error("Already closed")
    self._closed = True
    if self._sock is not None:
        # Hand-built packet: 3-byte payload length 1, sequence id 0, then
        # the COM_QUIT command byte.
        quit_packet = struct.pack("<iB", 1, COMMAND.COM_QUIT)
        try:
            # Best effort: the server may already be gone. The socket is
            # torn down in ``finally`` either way.
            self._write_bytes(quit_packet)
        except Exception:
            pass
        finally:
            self._force_close()

435 

@property
def open(self):
    """True while a socket is attached to the connection, False otherwise."""
    return not (self._sock is None)

440 

441 def _force_close(self): 

442 """Close connection without QUIT message.""" 

443 if self._rfile: 

444 self._rfile.close() 

445 if self._sock: 

446 try: 

447 self._sock.close() 

448 except: # noqa 

449 pass 

450 self._sock = None 

451 self._rfile = None 

452 

453 __del__ = _force_close 

454 

def autocommit(self, value):
    """Set the autocommit mode, sending ``SET AUTOCOMMIT`` only when it changes.

    :param value: Desired mode. Any value is accepted and normalised with
        ``bool()``; the normalised flag is stored in ``autocommit_mode``.
    """
    self.autocommit_mode = bool(value)
    current = self.get_autocommit()
    # Compare the normalised flag. Comparing the raw argument (e.g. "on", 2
    # or None) with the server's bool made every such call round-trip.
    if self.autocommit_mode != current:
        self._send_autocommit_mode()

460 

def get_autocommit(self):
    """Return whether the last server status recorded here has autocommit set."""
    flag = SERVER_STATUS.SERVER_STATUS_AUTOCOMMIT
    return (self.server_status & flag) != 0

463 

def _read_ok_packet(self):
    """Read the next packet, require it to be OK, and record its server status.

    :return: The ``OKPacketWrapper`` around the packet.
    :raise OperationalError: CR_COMMANDS_OUT_OF_SYNC if it is not an OK packet.
    """
    packet = self._read_packet()
    if packet.is_ok_packet():
        ok = OKPacketWrapper(packet)
        self.server_status = ok.server_status
        return ok
    raise err.OperationalError(
        CR.CR_COMMANDS_OUT_OF_SYNC,
        "Command Out of Sync",
    )

474 

def _send_autocommit_mode(self):
    """Send ``SET AUTOCOMMIT`` for ``self.autocommit_mode`` and await the OK."""
    statement = "SET AUTOCOMMIT = %s" % self.escape(self.autocommit_mode)
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()

481 

def begin(self):
    """Start a transaction by sending ``BEGIN``, then wait for the OK packet."""
    statement = "BEGIN"
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()

486 

def commit(self):
    """Send ``COMMIT`` and wait for the server's OK packet.

    See `Connection.commit() <https://www.python.org/dev/peps/pep-0249/#commit>`_
    in the specification.
    """
    statement = "COMMIT"
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()

496 

def rollback(self):
    """Send ``ROLLBACK`` and wait for the server's OK packet.

    See `Connection.rollback() <https://www.python.org/dev/peps/pep-0249/#rollback>`_
    in the specification.
    """
    statement = "ROLLBACK"
    self._execute_command(COMMAND.COM_QUERY, statement)
    self._read_ok_packet()

506 

def show_warnings(self):
    """Run ``SHOW WARNINGS`` and return the rows of its result."""
    self._execute_command(COMMAND.COM_QUERY, "SHOW WARNINGS")
    warnings_result = MySQLResult(self)
    warnings_result.read()
    return warnings_result.rows

513 

def select_db(self, db):
    """Switch the connection's current database with COM_INIT_DB.

    :param db: Name of the database to use.
    """
    command = COMMAND.COM_INIT_DB
    self._execute_command(command, db)
    self._read_ok_packet()

522 

def escape(self, obj, mapping=None):
    """Return *obj* rendered as an SQL literal string.

    ``bytes``/``bytearray`` go through ``_quote_bytes``, prefixed with
    ``_binary`` when the connection was created with ``binary_prefix=True``.
    ``str`` values are single-quoted after ``escape_string``. Anything else is
    delegated to ``converters.escape_item`` with the connection charset and
    *mapping*.

    Non-standard, for internal use; do not use this in your applications.
    """
    if isinstance(obj, (bytes, bytearray)):
        quoted = self._quote_bytes(obj)
        return "_binary" + quoted if self._binary_prefix else quoted
    if isinstance(obj, str):
        return "'" + self.escape_string(obj) + "'"
    return converters.escape_item(obj, self.charset, mapping=mapping)

536 

def literal(self, obj):
    """Escape *obj* with this connection's encoder table (alias of ``escape``).

    Non-standard, for internal use; do not use this in your applications.
    """
    encoders = self.encoders
    return self.escape(obj, encoders)

543 

def escape_string(self, s):
    """Escape *s* for use inside a single-quoted SQL string literal.

    When the server reports NO_BACKSLASH_ESCAPES, only single quotes are
    doubled. Otherwise ``converters.escape_string`` is used.
    """
    flag = SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES
    if not self.server_status & flag:
        return converters.escape_string(s)
    return s.replace("'", "''")

548 

def _quote_bytes(self, s):
    """Return the bytes *s* as a quoted SQL literal string.

    With NO_BACKSLASH_ESCAPES, single quotes are doubled and the bytes are
    decoded as ASCII with ``surrogateescape`` (non-ASCII bytes become lone
    surrogates). Otherwise ``converters.escape_bytes`` is used.
    """
    flag = SERVER_STATUS.SERVER_STATUS_NO_BACKSLASH_ESCAPES
    if not self.server_status & flag:
        return converters.escape_bytes(s)
    doubled = s.replace(b"'", b"''")
    return "'" + doubled.decode("ascii", "surrogateescape") + "'"

555 

def cursor(self, cursor=None):
    """Create a new cursor to execute queries with.

    :param cursor: Cursor class to instantiate. When it is None (or
        otherwise falsy), the connection's ``cursorclass`` is used.
    :type cursor: :py:class:`Cursor`, :py:class:`SSCursor`, :py:class:`DictCursor`,
        or :py:class:`SSDictCursor`.
    """
    factory = cursor if cursor else self.cursorclass
    return factory(self)

567 

568 # The following methods are INTERNAL USE ONLY (called from Cursor) 

def query(self, sql, unbuffered=False):
    """Send *sql* as COM_QUERY, read its result, and return the affected rows.

    ``str`` SQL is encoded with the connection encoding using
    ``surrogateescape``, so lone surrogates are written out as the raw bytes
    they stand for.
    """
    if isinstance(sql, str):
        payload = sql.encode(self.encoding, "surrogateescape")
    else:
        payload = sql
    self._execute_command(COMMAND.COM_QUERY, payload)
    affected = self._read_query_result(unbuffered=unbuffered)
    self._affected_rows = affected
    return affected

577 

def next_result(self, unbuffered=False):
    """Read the next result of a multi-result response.

    :return: Its affected-row count, which is also stored on the connection.
    """
    affected = self._read_query_result(unbuffered=unbuffered)
    self._affected_rows = affected
    return affected

581 

def affected_rows(self):
    """Return the affected-row count stored by the last query/next_result."""
    count = self._affected_rows
    return count

584 

def kill(self, thread_id):
    """Send ``KILL <thread_id>`` to the server.

    :raise TypeError: If *thread_id* is not an int.
    """
    if isinstance(thread_id, int):
        self.query(f"KILL {thread_id:d}")
        return
    raise TypeError("thread_id must be an integer")

589 

def ping(self, reconnect=True):
    """Check that the server is alive with COM_PING.

    :param reconnect: If the connection is closed or the ping fails,
        reconnect (and ping once more).
    :type reconnect: boolean

    :raise Error: If the connection is closed and reconnect=False.
    """
    if self._sock is None:
        if not reconnect:
            raise err.Error("Already closed")
        self.connect()
        # A connection just opened here is not reconnected a second time.
        reconnect = False
    try:
        self._execute_command(COMMAND.COM_PING, "")
        self._read_ok_packet()
    except Exception:
        if not reconnect:
            raise
        self.connect()
        self.ping(False)

614 

def set_charset(self, charset):
    """Deprecated. Use set_character_set() instead."""
    # Kept for compatibility with old PyMySQL. MySQLdb names this
    # operation set_character_set(), which is the preferred spelling.
    self.set_character_set(charset)

622 

def set_character_set(self, charset, collation=None):
    """Set the connection charset (and optionally the collation).

    Sends ``SET NAMES charset [COLLATE collation]``. After the server's reply
    has been read, updates ``charset``, ``encoding`` and ``collation``.
    """
    # Look up the charset first so an unsupported name fails before sending.
    encoding = charset_by_name(charset).encoding

    query = f"SET NAMES {charset}"
    if collation:
        query += f" COLLATE {collation}"
    self._execute_command(COMMAND.COM_QUERY, query)
    self._read_packet()
    self.charset, self.encoding, self.collation = charset, encoding, collation

642 

def connect(self, sock=None):
    """Open the connection, authenticate, and apply the session settings.

    :param sock: An already-connected socket to use instead of opening one
        from ``unix_socket`` or ``host``/``port``.
    :raise OperationalError: CR_CONN_HOST_ERROR when an ``OSError`` occurs.
        The original exception is kept on ``original_exception`` and its
        formatted traceback on ``traceback``.

    On any failure, both the connection and the socket opened here are
    closed before the exception propagates.
    """
    self._closed = False
    try:
        if sock is None:
            if self.unix_socket:
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                sock.settimeout(self.connect_timeout)
                sock.connect(self.unix_socket)
                self.host_info = "Localhost via UNIX socket"
                self._secure = True
                if DEBUG:
                    print("connected using unix_socket")
            else:
                kwargs = {}
                if self.bind_address is not None:
                    kwargs["source_address"] = (self.bind_address, 0)
                while True:
                    try:
                        sock = socket.create_connection(
                            (self.host, self.port), self.connect_timeout, **kwargs
                        )
                        break
                    except OSError as e:
                        if e.errno == errno.EINTR:
                            continue
                        raise
                self.host_info = "socket %s:%d" % (self.host, self.port)
                if DEBUG:
                    print("connected using socket")
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            sock.settimeout(None)

        self._sock = sock
        self._rfile = sock.makefile("rb")
        self._next_seq_id = 0

        self._get_server_information()
        self._request_authentication()

        # Send "SET NAMES" on init because:
        # - it ensures the charset (and collation) is set on the server;
        # - the collation_id in the handshake packet may be ignored;
        # - if no collation is specified, we don't know the server's default
        #   collation for the charset. For example, the default collation
        #   of utf8mb4 is:
        #   - MySQL 5.7, MariaDB 10.x: utf8mb4_general_ci
        #   - MySQL 8.0: utf8mb4_0900_ai_ci
        #
        # Reference:
        # - https://github.com/PyMySQL/PyMySQL/issues/1092
        # - https://github.com/wagtail/wagtail/issues/9477
        # - https://zenn.dev/methane/articles/2023-mysql-collation (Japanese)
        self.set_character_set(self.charset, self.collation)

        if self.sql_mode is not None:
            c = self.cursor()
            c.execute("SET sql_mode=%s", (self.sql_mode,))
            c.close()

        if self.init_command is not None:
            c = self.cursor()
            c.execute(self.init_command)
            c.close()

        if self.autocommit_mode is not None:
            self.autocommit(self.autocommit_mode)
    except BaseException as e:
        self._force_close()
        # _force_close() only closes self._sock. A socket created above may
        # not be attached yet (e.g. a failed UNIX connect() or setsockopt()),
        # so close it explicitly to avoid leaking it. Closing a socket that
        # was already closed or detached (by TLS wrapping) is harmless.
        if sock is not None:
            try:
                sock.close()
            except Exception:
                pass

        if isinstance(e, OSError):
            exc = err.OperationalError(
                CR.CR_CONN_HOST_ERROR,
                f"Can't connect to MySQL server on {self.host!r} ({e})",
            )
            # Keep the original exception and traceback for investigation.
            exc.original_exception = e
            exc.traceback = traceback.format_exc()
            if DEBUG:
                print(exc.traceback)
            raise exc

        # Anything else (MySQL errors raised during the handshake,
        # KeyboardInterrupt, or genuine bugs) is re-raised unchanged so the
        # original error and traceback are not hidden.
        raise

729 

def write_packet(self, payload):
    """Send *payload* as one complete MySQL packet.

    Prepends the 3-byte payload length and the current sequence id, sends the
    packet, then advances the sequence id modulo 256.
    """
    # Internal note: when you build a packet manually and call _write_bytes()
    # directly, you should set self._next_seq_id properly.
    header = _pack_int24(len(payload)) + bytes((self._next_seq_id,))
    data = header + payload
    if DEBUG:
        dump_packet(data)
    self._write_bytes(data)
    self._next_seq_id = (self._next_seq_id + 1) & 0xFF

741 

def _read_packet(self, packet_type=MysqlPacket):
    """Read one logical MySQL packet and wrap it in *packet_type*.

    A physical packet whose payload is exactly MAX_PACKET_LEN bytes is
    continued by the next one. Payloads are concatenated until a shorter
    packet arrives. Error packets are raised via ``raise_for_error()``.

    :raise OperationalError: If the connection to the MySQL server is lost.
    :raise InternalError: If the packet sequence number is wrong.
    """
    chunks = []
    while True:
        header = self._read_bytes(4)
        # if DEBUG: dump_packet(header)
        low, high, seq = struct.unpack("<HBB", header)
        length = low | (high << 16)
        if seq != self._next_seq_id:
            self._force_close()
            if seq == 0:
                # MariaDB sends an error packet with seqno == 0 on shutdown.
                raise err.OperationalError(
                    CR.CR_SERVER_LOST,
                    "Lost connection to MySQL server during query",
                )
            raise err.InternalError(
                "Packet sequence number wrong - got %d expected %d"
                % (seq, self._next_seq_id)
            )
        self._next_seq_id = (self._next_seq_id + 1) % 256

        chunk = self._read_bytes(length)
        if DEBUG:
            dump_packet(chunk)
        chunks.append(chunk)
        # https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
        if length < MAX_PACKET_LEN:
            break

    packet = packet_type(b"".join(chunks), self.encoding)
    if packet.is_error_packet():
        result = self._result
        if result is not None and result.unbuffered_active is True:
            result.unbuffered_active = False
        packet.raise_for_error()
    return packet

784 

785 def _read_bytes(self, num_bytes): 

786 self._sock.settimeout(self._read_timeout) 

787 while True: 

788 try: 

789 data = self._rfile.read(num_bytes) 

790 break 

791 except OSError as e: 

792 if e.errno == errno.EINTR: 

793 continue 

794 self._force_close() 

795 raise err.OperationalError( 

796 CR.CR_SERVER_LOST, 

797 f"Lost connection to MySQL server during query ({e})", 

798 ) 

799 except BaseException: 

800 # Don't convert unknown exception to MySQLError. 

801 self._force_close() 

802 raise 

803 if len(data) < num_bytes: 

804 self._force_close() 

805 raise err.OperationalError( 

806 CR.CR_SERVER_LOST, "Lost connection to MySQL server during query" 

807 ) 

808 return data 

809 

810 def _write_bytes(self, data): 

811 self._sock.settimeout(self._write_timeout) 

812 try: 

813 self._sock.sendall(data) 

814 except OSError as e: 

815 self._force_close() 

816 raise err.OperationalError( 

817 CR.CR_SERVER_GONE_ERROR, f"MySQL server has gone away ({e!r})" 

818 ) 

819 

def _read_query_result(self, unbuffered=False):
    """Read the response to the last command into a fresh ``MySQLResult``.

    With *unbuffered*, only the unbuffered reading is initialised; otherwise
    the whole result is read. The result becomes ``self._result``, and its
    server status (when present) is recorded on the connection.

    :return: The result's ``affected_rows``.
    """
    self._result = None
    result = MySQLResult(self)
    reader = result.init_unbuffered_query if unbuffered else result.read
    reader()
    self._result = result
    status = result.server_status
    if status is not None:
        self.server_status = status
    return result.affected_rows

831 

def insert_id(self):
    """Return the insert id of the last result, or 0 when there is none."""
    result = self._result
    return result.insert_id if result else 0

837 

def _execute_command(self, command, sql):
    """Send *command* with argument *sql* to the server.

    First drains any pending result of a previous command, warning if an
    unbuffered result was left unread. ``str`` *sql* is encoded with the
    connection encoding. An argument too long for one packet is split across
    MAX_PACKET_LEN-sized packets; an empty packet is appended when the data
    ends exactly on a packet boundary.

    :raise InterfaceError: If the connection is closed.
    """
    if not self._sock:
        raise err.InterfaceError(0, "")

    # If the last query was unbuffered, make sure it finishes before
    # sending new commands.
    if self._result is not None:
        if self._result.unbuffered_active:
            warnings.warn("Previous unbuffered result was left incomplete")
            self._result._finish_unbuffered_query()
        # next_result() replaces self._result, so re-read it each time.
        while self._result.has_next:
            self.next_result()
        self._result = None

    if isinstance(sql, str):
        sql = sql.encode(self.encoding)

    first_len = min(MAX_PACKET_LEN, len(sql) + 1)  # +1 is for command

    # Tiny optimization: build the first packet by hand instead of calling
    # write_packet(). "<i" writes the 3-byte length plus a zero high byte,
    # which doubles as sequence id 0.
    packet = struct.pack("<iB", first_len, command) + sql[: first_len - 1]
    self._write_bytes(packet)
    if DEBUG:
        dump_packet(packet)
    self._next_seq_id = 1

    if first_len < MAX_PACKET_LEN:
        return

    remaining = sql[first_len - 1 :]
    while True:
        chunk_len = min(MAX_PACKET_LEN, len(remaining))
        self.write_packet(remaining[:chunk_len])
        remaining = remaining[chunk_len:]
        if not remaining and chunk_len < MAX_PACKET_LEN:
            break

880 

881 def _request_authentication(self): 

882 # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse 

883 if int(self.server_version.split(".", 1)[0]) >= 5: 

884 self.client_flag |= CLIENT.MULTI_RESULTS 

885 

886 if self.user is None: 

887 raise ValueError("Did not specify a username") 

888 

889 charset_id = charset_by_name(self.charset).id 

890 if isinstance(self.user, str): 

891 self.user = self.user.encode(self.encoding) 

892 

893 data_init = struct.pack( 

894 "<iIB23s", self.client_flag, MAX_PACKET_LEN, charset_id, b"" 

895 ) 

896 

897 if self.ssl and self.server_capabilities & CLIENT.SSL: 

898 self.write_packet(data_init) 

899 

900 self._sock = self.ctx.wrap_socket(self._sock, server_hostname=self.host) 

901 self._rfile = self._sock.makefile("rb") 

902 self._secure = True 

903 

904 data = data_init + self.user + b"\0" 

905 

906 authresp = b"" 

907 plugin_name = None 

908 

909 if self._auth_plugin_name == "": 

910 plugin_name = b"" 

911 authresp = _auth.scramble_native_password(self.password, self.salt) 

912 elif self._auth_plugin_name == "mysql_native_password": 

913 plugin_name = b"mysql_native_password" 

914 authresp = _auth.scramble_native_password(self.password, self.salt) 

915 elif self._auth_plugin_name == "caching_sha2_password": 

916 plugin_name = b"caching_sha2_password" 

917 if self.password: 

918 if DEBUG: 

919 print("caching_sha2: trying fast path") 

920 authresp = _auth.scramble_caching_sha2(self.password, self.salt) 

921 else: 

922 if DEBUG: 

923 print("caching_sha2: empty password") 

924 elif self._auth_plugin_name == "sha256_password": 

925 plugin_name = b"sha256_password" 

926 if self.ssl and self.server_capabilities & CLIENT.SSL: 

927 authresp = self.password + b"\0" 

928 elif self.password: 

929 authresp = b"\1" # request public key 

930 else: 

931 authresp = b"\0" # empty password 

932 

933 if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA: 

934 data += _lenenc_int(len(authresp)) + authresp 

935 elif self.server_capabilities & CLIENT.SECURE_CONNECTION: 

936 data += struct.pack("B", len(authresp)) + authresp 

937 else: # pragma: no cover - not testing against servers without secure auth (>=5.0) 

938 data += authresp + b"\0" 

939 

940 if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB: 

941 if isinstance(self.db, str): 

942 self.db = self.db.encode(self.encoding) 

943 data += self.db + b"\0" 

944 

945 if self.server_capabilities & CLIENT.PLUGIN_AUTH: 

946 data += (plugin_name or b"") + b"\0" 

947 

948 if self.server_capabilities & CLIENT.CONNECT_ATTRS: 

949 connect_attrs = b"" 

950 for k, v in self._connect_attrs.items(): 

951 k = k.encode("utf-8") 

952 connect_attrs += _lenenc_int(len(k)) + k 

953 v = v.encode("utf-8") 

954 connect_attrs += _lenenc_int(len(v)) + v 

955 data += _lenenc_int(len(connect_attrs)) + connect_attrs 

956 

957 self.write_packet(data) 

958 auth_packet = self._read_packet() 

959 

960 # if authentication method isn't accepted the first byte 

961 # will have the octet 254 

962 if auth_packet.is_auth_switch_request(): 

963 if DEBUG: 

964 print("received auth switch") 

965 # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest 

966 auth_packet.read_uint8() # 0xfe packet identifier 

967 plugin_name = auth_packet.read_string() 

968 if ( 

969 self.server_capabilities & CLIENT.PLUGIN_AUTH 

970 and plugin_name is not None 

971 ): 

972 auth_packet = self._process_auth(plugin_name, auth_packet) 

973 else: 

974 raise err.OperationalError("received unknown auth switch request") 

975 elif auth_packet.is_extra_auth_data(): 

976 if DEBUG: 

977 print("received extra data") 

978 # https://dev.mysql.com/doc/internals/en/successful-authentication.html 

979 if self._auth_plugin_name == "caching_sha2_password": 

980 auth_packet = _auth.caching_sha2_password_auth(self, auth_packet) 

981 elif self._auth_plugin_name == "sha256_password": 

982 auth_packet = _auth.sha256_password_auth(self, auth_packet) 

983 else: 

984 raise err.OperationalError( 

985 "Received extra packet for auth method %r", self._auth_plugin_name 

986 ) 

987 

988 if DEBUG: 

989 print("Succeed to auth") 

990 

991 def _process_auth(self, plugin_name, auth_packet): 

992 handler = self._get_auth_plugin_handler(plugin_name) 

993 if handler: 

994 try: 

995 return handler.authenticate(auth_packet) 

996 except AttributeError: 

997 if plugin_name != b"dialog": 

998 raise err.OperationalError( 

999 CR.CR_AUTH_PLUGIN_CANNOT_LOAD, 

1000 f"Authentication plugin '{plugin_name}'" 

1001 f" not loaded: - {type(handler)!r} missing authenticate method", 

1002 ) 

1003 if plugin_name == b"caching_sha2_password": 

1004 return _auth.caching_sha2_password_auth(self, auth_packet) 

1005 elif plugin_name == b"sha256_password": 

1006 return _auth.sha256_password_auth(self, auth_packet) 

1007 elif plugin_name == b"mysql_native_password": 

1008 data = _auth.scramble_native_password(self.password, auth_packet.read_all()) 

1009 elif plugin_name == b"client_ed25519": 

1010 data = _auth.ed25519_password(self.password, auth_packet.read_all()) 

1011 elif plugin_name == b"mysql_old_password": 

1012 data = ( 

1013 _auth.scramble_old_password(self.password, auth_packet.read_all()) 

1014 + b"\0" 

1015 ) 

1016 elif plugin_name == b"mysql_clear_password": 

1017 # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html 

1018 data = self.password + b"\0" 

1019 elif plugin_name == b"dialog": 

1020 pkt = auth_packet 

1021 while True: 

1022 flag = pkt.read_uint8() 

1023 echo = (flag & 0x06) == 0x02 

1024 last = (flag & 0x01) == 0x01 

1025 prompt = pkt.read_all() 

1026 

1027 if prompt == b"Password: ": 

1028 self.write_packet(self.password + b"\0") 

1029 elif handler: 

1030 resp = "no response - TypeError within plugin.prompt method" 

1031 try: 

1032 resp = handler.prompt(echo, prompt) 

1033 self.write_packet(resp + b"\0") 

1034 except AttributeError: 

1035 raise err.OperationalError( 

1036 CR.CR_AUTH_PLUGIN_CANNOT_LOAD, 

1037 f"Authentication plugin '{plugin_name}'" 

1038 f" not loaded: - {handler!r} missing prompt method", 

1039 ) 

1040 except TypeError: 

1041 raise err.OperationalError( 

1042 CR.CR_AUTH_PLUGIN_ERR, 

1043 f"Authentication plugin '{plugin_name}'" 

1044 f" {handler!r} didn't respond with string. Returned '{resp!r}' to prompt {prompt!r}", 

1045 ) 

1046 else: 

1047 raise err.OperationalError( 

1048 CR.CR_AUTH_PLUGIN_CANNOT_LOAD, 

1049 f"Authentication plugin '{plugin_name}' not configured", 

1050 ) 

1051 pkt = self._read_packet() 

1052 pkt.check_error() 

1053 if pkt.is_ok_packet() or last: 

1054 break 

1055 return pkt 

1056 else: 

1057 raise err.OperationalError( 

1058 CR.CR_AUTH_PLUGIN_CANNOT_LOAD, 

1059 "Authentication plugin '%s' not configured" % plugin_name, 

1060 ) 

1061 

1062 self.write_packet(data) 

1063 pkt = self._read_packet() 

1064 pkt.check_error() 

1065 return pkt 

1066 

1067 def _get_auth_plugin_handler(self, plugin_name): 

1068 plugin_class = self._auth_plugin_map.get(plugin_name) 

1069 if not plugin_class and isinstance(plugin_name, bytes): 

1070 plugin_class = self._auth_plugin_map.get(plugin_name.decode("ascii")) 

1071 if plugin_class: 

1072 try: 

1073 handler = plugin_class(self) 

1074 except TypeError: 

1075 raise err.OperationalError( 

1076 CR.CR_AUTH_PLUGIN_CANNOT_LOAD, 

1077 f"Authentication plugin '{plugin_name}'" 

1078 f" not loaded: - {plugin_class!r} cannot be constructed with connection object", 

1079 ) 

1080 else: 

1081 handler = None 

1082 return handler 

1083 

1084 # _mysql support 

def thread_id(self):
    """Return the thread id the server announced in its handshake."""
    announced_ids = self.server_thread_id
    return announced_ids[0]

1087 

def character_set_name(self):
    """Return the connection's character set name (``self.charset``)."""
    charset_name = self.charset
    return charset_name

1090 

def get_host_info(self):
    """Return the stored ``host_info`` attribute unchanged."""
    info = self.host_info
    return info

1093 

def get_proto_info(self):
    """Return the protocol version byte read from the server handshake."""
    version = self.protocol_version
    return version

1096 

def _get_server_information(self):
    """Read and parse the server's initial handshake packet.

    Always sets ``protocol_version``, ``server_version``,
    ``server_thread_id`` (a 1-tuple), ``salt`` and ``server_capabilities``.

    When the packet is long enough, it also sets ``server_language``,
    ``server_charset``, ``server_status`` and the upper capability bits, and
    extends ``salt``. If the server advertises CLIENT.PLUGIN_AUTH, it sets
    ``_auth_plugin_name``. A module-level ``_DEFAULT_AUTH_PLUGIN`` overrides
    the plugin name (used by tests).
    """
    handshake = self._read_packet().get_all_data()

    self.protocol_version = handshake[0]
    pos = 1

    # NUL-terminated server version string.
    version_end = handshake.find(b"\0", pos)
    self.server_version = handshake[pos:version_end].decode("latin1")
    pos = version_end + 1

    self.server_thread_id = struct.unpack("<I", handshake[pos : pos + 4])
    pos += 4

    # First 8 bytes of the auth salt, then one filler byte.
    self.salt = handshake[pos : pos + 8]
    pos += 8 + 1

    # Lower 16 bits of the capability flags (upper bits are OR-ed in below).
    self.server_capabilities = struct.unpack("<H", handshake[pos : pos + 2])[0]
    pos += 2

    if len(handshake) >= pos + 6:
        lang, stat, cap_h, salt_len = struct.unpack(
            "<BHHB", handshake[pos : pos + 6]
        )
        pos += 6
        # TODO: deprecate server_language and server_charset.
        # mysqlclient-python doesn't provide it.
        self.server_language = lang
        try:
            self.server_charset = charset_by_id(lang).name
        except KeyError:
            # unknown collation
            self.server_charset = None

        self.server_status = stat
        if DEBUG:
            print("server_status: %x" % stat)

        self.server_capabilities |= cap_h << 16
        if DEBUG:
            print("salt_len:", salt_len)
        # The advertised length counts salt part 1 and the filler (8 + 1).
        salt_len = max(12, salt_len - 9)

        pos += 10  # reserved bytes

        if len(handshake) >= pos + salt_len:
            self.salt += handshake[pos : pos + salt_len]
            pos += salt_len

    # Skip one byte (presumably the salt's terminating NUL).
    pos += 1

    # AUTH PLUGIN NAME may appear here.
    if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(handshake) >= pos:
        # Due to Bug#59453 the auth-plugin-name is missing the terminating
        # NUL-char in versions prior to 5.5.10 and 5.6.2.
        # ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
        # didn't use version checks as mariadb is corrected and reports
        # earlier than those two.
        name_end = handshake.find(b"\0", pos)
        # No NUL found: the name is the last field, so take the rest.
        raw_name = handshake[pos:] if name_end < 0 else handshake[pos:name_end]
        self._auth_plugin_name = raw_name.decode("utf-8")

    if _DEFAULT_AUTH_PLUGIN is not None:  # for tests
        self._auth_plugin_name = _DEFAULT_AUTH_PLUGIN

1164 

def get_server_info(self):
    """Return the server version string parsed from the handshake."""
    version_string = self.server_version
    return version_string

1167 

# Re-export the exception classes from ``err`` as class attributes so they
# are reachable from a connection object (``conn.Error`` etc.), cf. the
# PEP 249 "Connection.Error, Connection.ProgrammingError, etc." extension.
(
    Warning,
    Error,
    InterfaceError,
    DatabaseError,
    DataError,
    OperationalError,
    IntegrityError,
    InternalError,
    ProgrammingError,
    NotSupportedError,
) = (
    err.Warning,
    err.Error,
    err.InterfaceError,
    err.DatabaseError,
    err.DataError,
    err.OperationalError,
    err.IntegrityError,
    err.InternalError,
    err.ProgrammingError,
    err.NotSupportedError,
)

1178 

1179 

class MySQLResult:
    """Reader for the server's response to a single statement.

    Buffered use: ``read()`` consumes the whole response: an OK packet, a
    LOAD DATA LOCAL exchange, or a result set including all rows.

    Unbuffered use: ``init_unbuffered_query()`` reads only the column
    metadata, and rows are then pulled one at a time with
    ``_read_rowdata_packet_unbuffered()``.

    ``self.connection`` is reset to ``None`` once the response has been
    fully consumed, releasing the reference back to the connection.
    """

    def __init__(self, connection):
        """
        :type connection: Connection
        """
        self.connection = connection
        self.affected_rows = None
        self.insert_id = None
        self.server_status = None
        self.warning_count = 0
        self.message = None
        self.field_count = 0
        self.description = None
        self.rows = None
        self.has_next = None
        # True while an unbuffered result still has unread rows.
        self.unbuffered_active = False

    def __del__(self):
        # Drain rows an unbuffered query left unread
        # (see _finish_unbuffered_query).
        if self.unbuffered_active:
            self._finish_unbuffered_query()

    def read(self):
        """Read the complete response to a statement (buffered mode)."""
        try:
            first_packet = self.connection._read_packet()

            if first_packet.is_ok_packet():
                self._read_ok_packet(first_packet)
            elif first_packet.is_load_local_packet():
                self._read_load_local_packet(first_packet)
            else:
                self._read_result_packet(first_packet)
        finally:
            # Release the connection reference even when reading failed.
            self.connection = None

    def init_unbuffered_query(self):
        """Start reading a response in unbuffered mode.

        For a result set, only the column descriptors are read here; rows
        stay on the wire until fetched.

        :raise OperationalError: If the connection to the MySQL server is lost.
        :raise InternalError:
        """
        first_packet = self.connection._read_packet()

        if first_packet.is_ok_packet():
            self.connection = None
            self._read_ok_packet(first_packet)
        elif first_packet.is_load_local_packet():
            try:
                self._read_load_local_packet(first_packet)
            finally:
                self.connection = None
        else:
            self.field_count = first_packet.read_length_encoded_integer()
            self._get_descriptions()

            # Apparently, MySQLdb picks this number because it's the maximum
            # value of a 64bit unsigned integer. Since we're emulating MySQLdb,
            # we set it to this instead of None, which would be preferred.
            self.affected_rows = 18446744073709551615
            self.unbuffered_active = True

    def _read_ok_packet(self, first_packet):
        """Copy the status fields of an OK packet onto this result."""
        ok_packet = OKPacketWrapper(first_packet)
        self.affected_rows = ok_packet.affected_rows
        self.insert_id = ok_packet.insert_id
        self.server_status = ok_packet.server_status
        self.warning_count = ok_packet.warning_count
        self.message = ok_packet.message
        self.has_next = ok_packet.has_next

    def _read_load_local_packet(self, first_packet):
        """Answer a LOAD_LOCAL request by sending the requested local file.

        NOTE: the file name comes from the server's packet, which is why this
        is refused unless the connection enabled ``local_infile``.

        :raise RuntimeError: If ``local_infile`` is disabled on the connection.
        """
        if not self.connection._local_infile:
            raise RuntimeError(
                "**WARN**: Received LOAD_LOCAL packet but local_infile option is false."
            )
        load_packet = LoadLocalPacketWrapper(first_packet)
        sender = LoadLocalFile(load_packet.filename, self.connection)
        try:
            sender.send_data()
        except BaseException:
            # Consume the server's reply to the (terminated) upload before
            # re-raising the original error.
            self.connection._read_packet()  # skip ok packet
            raise

        ok_packet = self.connection._read_packet()
        if (
            not ok_packet.is_ok_packet()
        ):  # pragma: no cover - upstream induced protocol error
            raise err.OperationalError(
                CR.CR_COMMANDS_OUT_OF_SYNC,
                "Commands Out of Sync",
            )
        self._read_ok_packet(ok_packet)

    def _check_packet_is_eof(self, packet):
        """Return True (and record warning/has_next state) for an EOF packet."""
        if not packet.is_eof_packet():
            return False
        # TODO: Support CLIENT.DEPRECATE_EOF
        # 1) Add DEPRECATE_EOF to CAPABILITIES
        # 2) Mask CAPABILITIES with server_capabilities
        # 3) if server_capabilities & CLIENT.DEPRECATE_EOF:
        #    use OKPacketWrapper instead of EOFPacketWrapper
        wp = EOFPacketWrapper(packet)
        self.warning_count = wp.warning_count
        self.has_next = wp.has_next
        return True

    def _read_result_packet(self, first_packet):
        """Read a full result set whose header is *first_packet*."""
        self.field_count = first_packet.read_length_encoded_integer()
        self._get_descriptions()
        self._read_rowdata_packet()

    def _read_rowdata_packet_unbuffered(self):
        """Fetch the next row of an unbuffered result.

        :return: The row tuple, or ``None`` when no query is active or the
            final EOF packet has been reached.
        """
        # Check if in an active query
        if not self.unbuffered_active:
            return

        # EOF
        packet = self.connection._read_packet()
        if self._check_packet_is_eof(packet):
            self.unbuffered_active = False
            self.connection = None
            self.rows = None
            return

        row = self._read_row_from_packet(packet)
        self.affected_rows = 1
        self.rows = (row,)  # rows should tuple of row for MySQL-python compatibility.
        return row

    def _finish_unbuffered_query(self):
        """Discard all remaining rows of an unbuffered result."""
        # After much reading on the MySQL protocol, it appears that there is,
        # in fact, no way to stop MySQL from sending all the data after
        # executing a query, so we just spin, and wait for an EOF packet.
        while self.unbuffered_active:
            try:
                packet = self.connection._read_packet()
            except err.OperationalError as e:
                if e.args[0] in (
                    ER.QUERY_TIMEOUT,
                    ER.STATEMENT_TIMEOUT,
                ):
                    # if the query timed out we can simply ignore this error
                    self.unbuffered_active = False
                    self.connection = None
                    return

                raise

            if self._check_packet_is_eof(packet):
                self.unbuffered_active = False
                self.connection = None  # release reference to kill cyclic reference.

    def _read_rowdata_packet(self):
        """Read a rowdata packet for each data row in the result set."""
        rows = []
        while True:
            packet = self.connection._read_packet()
            if self._check_packet_is_eof(packet):
                self.connection = None  # release reference to kill cyclic reference.
                break
            rows.append(self._read_row_from_packet(packet))

        self.affected_rows = len(rows)
        self.rows = tuple(rows)

    def _read_row_from_packet(self, packet):
        """Decode one row using the per-column ``(encoding, converter)`` pairs.

        SQL NULL (``None``) values are passed through without conversion.
        """
        row = []
        for encoding, converter in self.converters:
            try:
                data = packet.read_length_coded_string()
            except IndexError:
                # No more columns in this row
                # See https://github.com/PyMySQL/PyMySQL/pull/434
                break
            if data is not None:
                if encoding is not None:
                    data = data.decode(encoding)
                if DEBUG:
                    print("DEBUG: DATA = ", data)
                if converter is not None:
                    data = converter(data)
            row.append(data)
        return tuple(row)

    def _get_descriptions(self):
        """Read a column descriptor packet for each column in the result.

        Fills ``self.fields``, ``self.converters`` (one ``(encoding,
        converter)`` pair per column) and ``self.description``.

        :raise OperationalError: If the descriptors are not followed by an
            EOF packet.
        """
        self.fields = []
        self.converters = []
        use_unicode = self.connection.use_unicode
        conn_encoding = self.connection.encoding
        description = []

        for _ in range(self.field_count):
            field = self.connection._read_packet(FieldDescriptorPacket)
            self.fields.append(field)
            description.append(field.description())
            field_type = field.type_code
            if use_unicode:
                if field_type == FIELD_TYPE.JSON:
                    # When SELECT from JSON column: charset = binary
                    # When SELECT CAST(... AS JSON): charset = connection encoding
                    # This behavior is different from TEXT / BLOB.
                    # We should decode result by connection encoding regardless charsetnr.
                    # See https://github.com/PyMySQL/PyMySQL/issues/488
                    encoding = conn_encoding  # SELECT CAST(... AS JSON)
                elif field_type in TEXT_TYPES:
                    if field.charsetnr == 63:  # binary
                        # TEXTs with charset=binary means BINARY types.
                        encoding = None
                    else:
                        encoding = conn_encoding
                else:
                    # Integers, Dates and Times, and other basic data is encoded in ascii
                    encoding = "ascii"
            else:
                encoding = None
            converter = self.connection.decoders.get(field_type)
            if converter is converters.through:
                converter = None
            if DEBUG:
                print(f"DEBUG: field={field}, converter={converter}")
            self.converters.append((encoding, converter))

        eof_packet = self.connection._read_packet()
        if not eof_packet.is_eof_packet():
            # Raise instead of ``assert``: assertions are stripped under
            # ``python -O``, which would let this protocol error go unnoticed.
            raise err.OperationalError(
                CR.CR_COMMANDS_OUT_OF_SYNC,
                "Protocol error, expecting EOF",
            )
        self.description = tuple(description)

1404 

1405 

class LoadLocalFile:
    """Sends a local file to the server in answer to a LOAD_LOCAL request."""

    def __init__(self, filename, connection):
        self.filename = filename  # path named in the server's request
        self.connection = connection

    def send_data(self):
        """Send data packets from the local file to the server

        The file is written in packets of at most
        ``min(max_allowed_packet, 16 KiB)`` bytes. An empty packet then marks
        the end of the data. The empty packet is sent even when sending the
        file failed, unless the connection is closed by then.

        :raise InterfaceError: If the connection has no socket.
        :raise OperationalError: If an OSError occurs while opening, reading
            or sending the file.
        """
        conn: Connection = self.connection
        if not conn._sock:
            raise err.InterfaceError(0, "")

        try:
            with open(self.filename, "rb") as source:
                # 16KB is efficient enough; never exceed the server's limit.
                chunk_size = min(conn.max_allowed_packet, 16 * 1024)
                # read() returns b"" at end of file, which stops the iteration.
                for chunk in iter(lambda: source.read(chunk_size), b""):
                    conn.write_packet(chunk)
        except OSError:
            raise err.OperationalError(
                ER.FILE_NOT_FOUND,
                f"Can't find file '{self.filename}'",
            )
        finally:
            if not conn._closed:
                # send the empty packet to signify we are done sending data
                conn.write_packet(b"")