Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/sqlalchemy/dialects/postgresql/asyncpg.py: 42%

534 statements  

coverage.py v7.0.1, created at 2022-12-25 06:11 +0000

1# postgresql/asyncpg.py 

2# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors <see AUTHORS 

3# file> 

4# 

5# This module is part of SQLAlchemy and is released under 

6# the MIT License: https://www.opensource.org/licenses/mit-license.php 

7r""" 

8.. dialect:: postgresql+asyncpg 

9 :name: asyncpg 

10 :dbapi: asyncpg 

11 :connectstring: postgresql+asyncpg://user:password@host:port/dbname[?key=value&key=value...] 

12 :url: https://magicstack.github.io/asyncpg/ 

13 

14The asyncpg dialect is SQLAlchemy's first Python asyncio dialect. 

15 

16Using a special asyncio mediation layer, the asyncpg dialect is usable 

17as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>` 

18extension package. 

19 

20This dialect should normally be used only with the 

21:func:`_asyncio.create_async_engine` engine creation function:: 

22 

23 from sqlalchemy.ext.asyncio import create_async_engine 

24 engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname") 

25 

26The dialect can also be run as a "synchronous" dialect within the 

27:func:`_sa.create_engine` function, which will pass "await" calls into 

28an ad-hoc event loop. This mode of operation is of **limited use** 

29and is for special testing scenarios only. The mode can be enabled by 

30adding the SQLAlchemy-specific flag ``async_fallback`` to the URL 

31in conjunction with :func:`_sa.create_engine`:: 

32 

33 # for testing purposes only; do not use in production! 

34 engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true") 

35 

36 

37.. versionadded:: 1.4 

38 

39.. note:: 

40 

41 By default asyncpg does not decode the ``json`` and ``jsonb`` types and 

42 returns them as strings. SQLAlchemy sets a default type decoder for the 

43 ``json`` and ``jsonb`` types using the Python builtin ``json.loads`` function. 

44 The JSON implementation used can be changed by setting the 

45 ``json_deserializer`` argument when creating the engine with 

46 :func:`create_engine` or :func:`create_async_engine`, as in the sketch below. 

47 

48 
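A minimal sketch of overriding the deserializer, assuming the third-party
``orjson`` package is installed (any callable accepting a JSON string may be
used in its place)::

    import orjson
    from sqlalchemy.ext.asyncio import create_async_engine

    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname",
        json_deserializer=orjson.loads,
    )
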

49.. _asyncpg_prepared_statement_cache: 

50 

51Prepared Statement Cache 

52-------------------------- 

53 

54 The asyncpg SQLAlchemy dialect makes use of ``asyncpg.connection.prepare()`` 

55 for all statements. The prepared statement objects are cached after 

56 construction, which appears to grant a performance improvement of 10% or 

57 more for statement invocation. The cache is on a per-DBAPI-connection basis, 

58 which means that the primary storage for prepared statements is within DBAPI 

59 connections pooled within the connection pool. The size of this cache 

60 defaults to 100 statements per DBAPI connection and may be adjusted using the 

61 ``prepared_statement_cache_size`` DBAPI argument (note that while this argument 

62 is implemented by SQLAlchemy, it is part of the DBAPI emulation portion of the 

63 asyncpg dialect, and is therefore handled as a DBAPI argument, not a dialect 

64 argument):: 

65 

66 

67 engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500") 

68 

69To disable the prepared statement cache, use a value of zero:: 

70 

71 engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0") 

72 

73.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg. 

74 

75 

76.. warning:: The ``asyncpg`` database driver necessarily uses caches for 

77 PostgreSQL type OIDs, which become stale when custom PostgreSQL datatypes 

78 such as ``ENUM`` objects are changed via DDL operations. Additionally, 

79 prepared statements themselves which are optionally cached by SQLAlchemy's 

80 driver as described above may also become "stale" when DDL has been emitted 

81 to the PostgreSQL database which modifies the tables or other objects 

82 involved in a particular prepared statement. 

83 

84 The SQLAlchemy asyncpg dialect will invalidate these caches within its local 

85 process when statements that represent DDL are emitted on a local 

86 connection, but this is only controllable within a single Python process / 

87 database engine. If DDL changes are made from other database engines 

88 and/or processes, a running application may encounter the asyncpg exceptions 

89 ``InvalidCachedStatementError`` and/or ``InternalServerError("cache lookup 

90 failed for type <oid>")`` when it uses pooled database connections that 

91 operated upon the previous structures. The SQLAlchemy asyncpg dialect 

92 recovers from these error cases by clearing its internal caches as well as 

93 those of the asyncpg driver when these exceptions are raised, but it cannot 

94 prevent them from being raised in the first 

95 place if the cached prepared statement or asyncpg type caches have gone 

96 stale, nor can it retry the statement as the PostgreSQL transaction is 

97 invalidated when these errors occur. 

98 

99Disabling the PostgreSQL JIT to improve ENUM datatype handling 

100--------------------------------------------------------------- 

101 

102Asyncpg has an `issue <https://github.com/MagicStack/asyncpg/issues/727>`_ when 

103 using PostgreSQL ENUM datatypes, where upon the creation of new database 

104 connections, an expensive query may be emitted in order to retrieve metadata 

105 regarding custom types, which has been shown to negatively affect performance. 

106To mitigate this issue, the PostgreSQL "jit" setting may be disabled from the 

107client using this setting passed to :func:`_asyncio.create_async_engine`:: 

108 

109 engine = create_async_engine( 

110 "postgresql+asyncpg://user:password@localhost/tmp", 

111 connect_args={"server_settings": {"jit": "off"}}, 

112 ) 

113 

114.. seealso:: 

115 

116 https://github.com/MagicStack/asyncpg/issues/727 

117 

118""" # noqa 

119 

120import collections 

121import decimal 

122import json as _py_json 

123import re 

124import time 

125 

126from . import json 

127from .base import _DECIMAL_TYPES 

128from .base import _FLOAT_TYPES 

129from .base import _INT_TYPES 

130from .base import ENUM 

131from .base import INTERVAL 

132from .base import OID 

133from .base import PGCompiler 

134from .base import PGDialect 

135from .base import PGExecutionContext 

136from .base import PGIdentifierPreparer 

137from .base import REGCLASS 

138from .base import UUID 

139from ... import exc 

140from ... import pool 

141from ... import processors 

142from ... import util 

143from ...engine import AdaptedConnection 

144from ...sql import sqltypes 

145from ...util.concurrency import asyncio 

146from ...util.concurrency import await_fallback 

147from ...util.concurrency import await_only 

148 

149 

150try: 

151 from uuid import UUID as _python_UUID # noqa 

152except ImportError: 

153 _python_UUID = None 

154 

155 

156class AsyncpgTime(sqltypes.Time): 

157 def get_dbapi_type(self, dbapi): 

158 if self.timezone: 

159 return dbapi.TIME_W_TZ 

160 else: 

161 return dbapi.TIME 

162 

163 

164class AsyncpgDate(sqltypes.Date): 

165 def get_dbapi_type(self, dbapi): 

166 return dbapi.DATE 

167 

168 

169class AsyncpgDateTime(sqltypes.DateTime): 

170 def get_dbapi_type(self, dbapi): 

171 if self.timezone: 

172 return dbapi.TIMESTAMP_W_TZ 

173 else: 

174 return dbapi.TIMESTAMP 

175 

176 

177class AsyncpgBoolean(sqltypes.Boolean): 

178 def get_dbapi_type(self, dbapi): 

179 return dbapi.BOOLEAN 

180 

181 

182class AsyncPgInterval(INTERVAL): 

183 def get_dbapi_type(self, dbapi): 

184 return dbapi.INTERVAL 

185 

186 @classmethod 

187 def adapt_emulated_to_native(cls, interval, **kw): 

188 

189 return AsyncPgInterval(precision=interval.second_precision) 

190 

191 

192class AsyncPgEnum(ENUM): 

193 def get_dbapi_type(self, dbapi): 

194 return dbapi.ENUM 

195 

196 

197class AsyncpgInteger(sqltypes.Integer): 

198 def get_dbapi_type(self, dbapi): 

199 return dbapi.INTEGER 

200 

201 

202class AsyncpgBigInteger(sqltypes.BigInteger): 

203 def get_dbapi_type(self, dbapi): 

204 return dbapi.BIGINTEGER 

205 

206 

207class AsyncpgJSON(json.JSON): 

208 def get_dbapi_type(self, dbapi): 

209 return dbapi.JSON 

210 

211 def result_processor(self, dialect, coltype): 

212 return None 

213 

214 

215class AsyncpgJSONB(json.JSONB): 

216 def get_dbapi_type(self, dbapi): 

217 return dbapi.JSONB 

218 

219 def result_processor(self, dialect, coltype): 

220 return None 

221 

222 

223class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType): 

224 def get_dbapi_type(self, dbapi): 

225 raise NotImplementedError("should not be here") 

226 

227 

228class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType): 

229 def get_dbapi_type(self, dbapi): 

230 return dbapi.INTEGER 

231 

232 

233class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType): 

234 def get_dbapi_type(self, dbapi): 

235 return dbapi.STRING 

236 

237 

238class AsyncpgJSONPathType(json.JSONPathType): 

239 def bind_processor(self, dialect): 

240 def process(value): 

241 assert isinstance(value, util.collections_abc.Sequence) 

242 tokens = [util.text_type(elem) for elem in value] 

243 return tokens 

244 

245 return process 

246 

247 

248class AsyncpgUUID(UUID): 

249 def get_dbapi_type(self, dbapi): 

250 return dbapi.UUID 

251 

252 def bind_processor(self, dialect): 

253 if not self.as_uuid and dialect.use_native_uuid: 

254 

255 def process(value): 

256 if value is not None: 

257 value = _python_UUID(value) 

258 return value 

259 

260 return process 

261 

262 def result_processor(self, dialect, coltype): 

263 if not self.as_uuid and dialect.use_native_uuid: 

264 

265 def process(value): 

266 if value is not None: 

267 value = str(value) 

268 return value 

269 

270 return process 

271 

272 

273class AsyncpgNumeric(sqltypes.Numeric): 

274 def get_dbapi_type(self, dbapi): 

275 return dbapi.NUMBER 

276 

277 def bind_processor(self, dialect): 

278 return None 

279 

280 def result_processor(self, dialect, coltype): 

281 if self.asdecimal: 

282 if coltype in _FLOAT_TYPES: 

283 return processors.to_decimal_processor_factory( 

284 decimal.Decimal, self._effective_decimal_return_scale 

285 ) 

286 elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: 

287 # asyncpg returns Decimal natively for 1700 (NUMERIC) 

288 return None 

289 else: 

290 raise exc.InvalidRequestError( 

291 "Unknown PG numeric type: %d" % coltype 

292 ) 

293 else: 

294 if coltype in _FLOAT_TYPES: 

295 # asyncpg returns float natively for 701 (FLOAT8) 

296 return None 

297 elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: 

298 return processors.to_float 

299 else: 

300 raise exc.InvalidRequestError( 

301 "Unknown PG numeric type: %d" % coltype 

302 ) 

303 

304 

305class AsyncpgFloat(AsyncpgNumeric): 

306 def get_dbapi_type(self, dbapi): 

307 return dbapi.FLOAT 

308 

309 

310class AsyncpgREGCLASS(REGCLASS): 

311 def get_dbapi_type(self, dbapi): 

312 return dbapi.STRING 

313 

314 

315class AsyncpgOID(OID): 

316 def get_dbapi_type(self, dbapi): 

317 return dbapi.INTEGER 

318 

319 

320class PGExecutionContext_asyncpg(PGExecutionContext): 

321 def handle_dbapi_exception(self, e): 
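# asyncpg raises InvalidCachedStatementError / InternalServerError when its
# prepared statement or type OID caches have gone stale after DDL; bump the
# dialect-wide schema cache timestamp so that adapted connections reload
# their type state before their next execution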

322 if isinstance( 

323 e, 

324 ( 

325 self.dialect.dbapi.InvalidCachedStatementError, 

326 self.dialect.dbapi.InternalServerError, 

327 ), 

328 ): 

329 self.dialect._invalidate_schema_cache() 

330 

331 def pre_exec(self): 
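# DDL bumps the dialect-wide schema cache invalidation timestamp; the current
# timestamp is then stamped onto the cursor so the adapted connection can
# reload asyncpg's type caches before executing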

332 if self.isddl: 

333 self.dialect._invalidate_schema_cache() 

334 

335 self.cursor._invalidate_schema_cache_asof = ( 

336 self.dialect._invalidate_schema_cache_asof 

337 ) 

338 

339 if not self.compiled: 

340 return 

341 

342 # we have to exclude ENUM because "enum" is not really a "type" 

343 # we can cast to; it has to be the name of the type itself. 

344 # for now we just omit it from casting 

345 self.exclude_set_input_sizes = {AsyncAdapt_asyncpg_dbapi.ENUM} 

346 

347 def create_server_side_cursor(self): 

348 return self._dbapi_connection.cursor(server_side=True) 

349 

350 

351class PGCompiler_asyncpg(PGCompiler): 

352 pass 

353 

354 

355class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer): 

356 pass 

357 

358 

359class AsyncAdapt_asyncpg_cursor: 
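# pep-249 style cursor facade over an asyncpg connection; this default
# client-side version buffers the complete result of prepared_stmt.fetch()
# in self._rows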

360 __slots__ = ( 

361 "_adapt_connection", 

362 "_connection", 

363 "_rows", 

364 "description", 

365 "arraysize", 

366 "rowcount", 

367 "_inputsizes", 

368 "_cursor", 

369 "_invalidate_schema_cache_asof", 

370 ) 

371 

372 server_side = False 

373 

374 def __init__(self, adapt_connection): 

375 self._adapt_connection = adapt_connection 

376 self._connection = adapt_connection._connection 

377 self._rows = [] 

378 self._cursor = None 

379 self.description = None 

380 self.arraysize = 1 

381 self.rowcount = -1 

382 self._inputsizes = None 

383 self._invalidate_schema_cache_asof = 0 

384 

385 def close(self): 

386 self._rows[:] = [] 

387 

388 def _handle_exception(self, error): 

389 self._adapt_connection._handle_exception(error) 

390 

391 def _parameter_placeholders(self, params): 
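# rewrite DBAPI "format" style %s placeholders as asyncpg positional $n
# placeholders, appending a ::typename cast when setinputsizes() supplied a
# type for that position (see the _pg_types lookup later in this module)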

392 if not self._inputsizes: 

393 return tuple("$%d" % idx for idx, _ in enumerate(params, 1)) 

394 else: 

395 return tuple( 

396 "$%d::%s" % (idx, typ) if typ else "$%d" % idx 

397 for idx, typ in enumerate( 

398 (_pg_types.get(typ) for typ in self._inputsizes), 1 

399 ) 

400 ) 

401 

402 async def _prepare_and_execute(self, operation, parameters): 

403 adapt_connection = self._adapt_connection 

404 

405 async with adapt_connection._execute_mutex: 

406 

407 if not adapt_connection._started: 

408 await adapt_connection._start_transaction() 

409 

410 if parameters is not None: 

411 operation = operation % self._parameter_placeholders( 

412 parameters 

413 ) 

414 else: 

415 parameters = () 

416 

417 try: 

418 prepared_stmt, attributes = await adapt_connection._prepare( 

419 operation, self._invalidate_schema_cache_asof 

420 ) 

421 

422 if attributes: 

423 self.description = [ 

424 ( 

425 attr.name, 

426 attr.type.oid, 

427 None, 

428 None, 

429 None, 

430 None, 

431 None, 

432 ) 

433 for attr in attributes 

434 ] 

435 else: 

436 self.description = None 

437 

438 if self.server_side: 

439 self._cursor = await prepared_stmt.cursor(*parameters) 

440 self.rowcount = -1 

441 else: 

442 self._rows = await prepared_stmt.fetch(*parameters) 

443 status = prepared_stmt.get_statusmsg() 

444 
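# asyncpg status strings look like "INSERT 0 3", "UPDATE 2" or "DELETE 1";
# parse the trailing count so cursor.rowcount behaves like a pep-249 cursor
# for DML statements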

445 reg = re.match( 

446 r"(?:UPDATE|DELETE|INSERT \d+) (\d+)", status 

447 ) 

448 if reg: 

449 self.rowcount = int(reg.group(1)) 

450 else: 

451 self.rowcount = -1 

452 

453 except Exception as error: 

454 self._handle_exception(error) 

455 

456 async def _executemany(self, operation, seq_of_parameters): 

457 adapt_connection = self._adapt_connection 

458 

459 async with adapt_connection._execute_mutex: 

460 await adapt_connection._check_type_cache_invalidation( 

461 self._invalidate_schema_cache_asof 

462 ) 

463 

464 if not adapt_connection._started: 

465 await adapt_connection._start_transaction() 

466 

467 operation = operation % self._parameter_placeholders( 

468 seq_of_parameters[0] 

469 ) 

470 

471 try: 

472 return await self._connection.executemany( 

473 operation, seq_of_parameters 

474 ) 

475 except Exception as error: 

476 self._handle_exception(error) 

477 

478 def execute(self, operation, parameters=None): 

479 self._adapt_connection.await_( 

480 self._prepare_and_execute(operation, parameters) 

481 ) 

482 

483 def executemany(self, operation, seq_of_parameters): 

484 return self._adapt_connection.await_( 

485 self._executemany(operation, seq_of_parameters) 

486 ) 

487 

488 def setinputsizes(self, *inputsizes): 

489 self._inputsizes = inputsizes 

490 

491 def __iter__(self): 

492 while self._rows: 

493 yield self._rows.pop(0) 

494 

495 def fetchone(self): 

496 if self._rows: 

497 return self._rows.pop(0) 

498 else: 

499 return None 

500 

501 def fetchmany(self, size=None): 

502 if size is None: 

503 size = self.arraysize 

504 

505 retval = self._rows[0:size] 

506 self._rows[:] = self._rows[size:] 

507 return retval 

508 

509 def fetchall(self): 

510 retval = self._rows[:] 

511 self._rows[:] = [] 

512 return retval 

513 

514 

515class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor): 
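# server-side variant of the cursor facade: rows are pulled from an asyncpg
# cursor object in batches and staged in a deque rather than fetched into
# memory all at once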

516 

517 server_side = True 

518 __slots__ = ("_rowbuffer",) 

519 

520 def __init__(self, adapt_connection): 

521 super(AsyncAdapt_asyncpg_ss_cursor, self).__init__(adapt_connection) 

522 self._rowbuffer = None 

523 

524 def close(self): 

525 self._cursor = None 

526 self._rowbuffer = None 

527 

528 def _buffer_rows(self): 

529 new_rows = self._adapt_connection.await_(self._cursor.fetch(50)) 

530 self._rowbuffer = collections.deque(new_rows) 

531 

532 def __aiter__(self): 

533 return self 

534 

535 async def __anext__(self): 

536 if not self._rowbuffer: 

537 self._buffer_rows() 

538 

539 while True: 

540 while self._rowbuffer: 

541 yield self._rowbuffer.popleft() 

542 

543 self._buffer_rows() 

544 if not self._rowbuffer: 

545 break 

546 

547 def fetchone(self): 

548 if not self._rowbuffer: 

549 self._buffer_rows() 

550 if not self._rowbuffer: 

551 return None 

552 return self._rowbuffer.popleft() 

553 

554 def fetchmany(self, size=None): 

555 if size is None: 

556 return self.fetchall() 

557 

558 if not self._rowbuffer: 

559 self._buffer_rows() 

560 

561 buf = list(self._rowbuffer) 

562 lb = len(buf) 

563 if size > lb: 

564 buf.extend( 

565 self._adapt_connection.await_(self._cursor.fetch(size - lb)) 

566 ) 

567 

568 result = buf[0:size] 

569 self._rowbuffer = collections.deque(buf[size:]) 

570 return result 

571 

572 def fetchall(self): 

573 ret = list(self._rowbuffer) + list( 

574 self._adapt_connection.await_(self._all()) 

575 ) 

576 self._rowbuffer.clear() 

577 return ret 

578 

579 async def _all(self): 

580 rows = [] 

581 

582 # TODO: looks like we have to hand-roll some kind of batching here. 

583 # hardcoding for the moment but this should be improved. 

584 while True: 

585 batch = await self._cursor.fetch(1000) 

586 if batch: 

587 rows.extend(batch) 

588 continue 

589 else: 

590 break 

591 return rows 

592 

593 def executemany(self, operation, seq_of_parameters): 

594 raise NotImplementedError( 

595 "server side cursor doesn't support executemany yet" 

596 ) 

597 

598 

599class AsyncAdapt_asyncpg_connection(AdaptedConnection): 
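# pep-249 style connection facade; asyncpg runs in autocommit by default, so
# the adapter begins an explicit transaction lazily via _start_transaction()
# and commit() / rollback() act upon that asyncpg Transaction object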

600 __slots__ = ( 

601 "dbapi", 

602 "_connection", 

603 "isolation_level", 

604 "_isolation_setting", 

605 "readonly", 

606 "deferrable", 

607 "_transaction", 

608 "_started", 

609 "_prepared_statement_cache", 

610 "_invalidate_schema_cache_asof", 

611 "_execute_mutex", 

612 ) 

613 

614 await_ = staticmethod(await_only) 

615 

616 def __init__(self, dbapi, connection, prepared_statement_cache_size=100): 

617 self.dbapi = dbapi 

618 self._connection = connection 

619 self.isolation_level = self._isolation_setting = "read_committed" 

620 self.readonly = False 

621 self.deferrable = False 

622 self._transaction = None 

623 self._started = False 

624 self._invalidate_schema_cache_asof = time.time() 

625 self._execute_mutex = asyncio.Lock() 

626 

627 if prepared_statement_cache_size: 

628 self._prepared_statement_cache = util.LRUCache( 

629 prepared_statement_cache_size 

630 ) 

631 else: 

632 self._prepared_statement_cache = None 

633 

634 async def _check_type_cache_invalidation(self, invalidate_timestamp): 

635 if invalidate_timestamp > self._invalidate_schema_cache_asof: 

636 await self._connection.reload_schema_state() 

637 self._invalidate_schema_cache_asof = invalidate_timestamp 

638 

639 async def _prepare(self, operation, invalidate_timestamp): 

640 await self._check_type_cache_invalidation(invalidate_timestamp) 

641 

642 cache = self._prepared_statement_cache 

643 if cache is None: 

644 prepared_stmt = await self._connection.prepare(operation) 

645 attributes = prepared_stmt.get_attributes() 

646 return prepared_stmt, attributes 

647 

648 # asyncpg uses a type cache for the "attributes" which seems to go 

649 # stale independently of the PreparedStatement itself, so place that 

650 # collection in the cache as well. 

651 if operation in cache: 

652 prepared_stmt, attributes, cached_timestamp = cache[operation] 

653 

654 # preparedstatements themselves also go stale for certain DDL 

655 # changes such as size of a VARCHAR changing, so there is also 

656 # a cross-connection invalidation timestamp 

657 if cached_timestamp > invalidate_timestamp: 

658 return prepared_stmt, attributes 

659 

660 prepared_stmt = await self._connection.prepare(operation) 

661 attributes = prepared_stmt.get_attributes() 

662 cache[operation] = (prepared_stmt, attributes, time.time()) 

663 

664 return prepared_stmt, attributes 

665 

666 def _handle_exception(self, error): 
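# translate raw asyncpg exceptions into the pep-249 exception hierarchy
# defined on AsyncAdapt_asyncpg_dbapi by walking the error's MRO through the
# _asyncpg_error_translate mapping; unrecognized errors are re-raised unchanged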

667 if self._connection.is_closed(): 

668 self._transaction = None 

669 self._started = False 

670 

671 if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error): 

672 exception_mapping = self.dbapi._asyncpg_error_translate 

673 

674 for super_ in type(error).__mro__: 

675 if super_ in exception_mapping: 

676 translated_error = exception_mapping[super_]( 

677 "%s: %s" % (type(error), error) 

678 ) 

679 translated_error.pgcode = ( 

680 translated_error.sqlstate 

681 ) = getattr(error, "sqlstate", None) 

682 raise translated_error from error 

683 else: 

684 raise error 

685 else: 

686 raise error 

687 

688 @property 

689 def autocommit(self): 

690 return self.isolation_level == "autocommit" 

691 

692 @autocommit.setter 

693 def autocommit(self, value): 

694 if value: 

695 self.isolation_level = "autocommit" 

696 else: 

697 self.isolation_level = self._isolation_setting 

698 

699 def set_isolation_level(self, level): 

700 if self._started: 

701 self.rollback() 

702 self.isolation_level = self._isolation_setting = level 

703 

704 async def _start_transaction(self): 

705 if self.isolation_level == "autocommit": 

706 return 

707 

708 try: 

709 self._transaction = self._connection.transaction( 

710 isolation=self.isolation_level, 

711 readonly=self.readonly, 

712 deferrable=self.deferrable, 

713 ) 

714 await self._transaction.start() 

715 except Exception as error: 

716 self._handle_exception(error) 

717 else: 

718 self._started = True 

719 

720 def cursor(self, server_side=False): 

721 if server_side: 

722 return AsyncAdapt_asyncpg_ss_cursor(self) 

723 else: 

724 return AsyncAdapt_asyncpg_cursor(self) 

725 

726 def rollback(self): 

727 if self._started: 

728 try: 

729 self.await_(self._transaction.rollback()) 

730 except Exception as error: 

731 self._handle_exception(error) 

732 finally: 

733 self._transaction = None 

734 self._started = False 

735 

736 def commit(self): 

737 if self._started: 

738 try: 

739 self.await_(self._transaction.commit()) 

740 except Exception as error: 

741 self._handle_exception(error) 

742 finally: 

743 self._transaction = None 

744 self._started = False 

745 

746 def close(self): 

747 self.rollback() 

748 

749 self.await_(self._connection.close()) 

750 

751 def terminate(self): 

752 self._connection.terminate() 

753 

754 

755class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection): 
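# variant used when async_fallback=true is present in the URL; await_fallback()
# drives coroutines on an ad-hoc event loop so the dialect can run under a
# plain synchronous Engine (testing use only)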

756 __slots__ = () 

757 

758 await_ = staticmethod(await_fallback) 

759 

760 

761class AsyncAdapt_asyncpg_dbapi: 

762 def __init__(self, asyncpg): 

763 self.asyncpg = asyncpg 

764 self.paramstyle = "format" 

765 

766 def connect(self, *arg, **kw): 
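# pop the SQLAlchemy-specific arguments before handing the remaining keyword
# arguments through to asyncpg.connect()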

767 async_fallback = kw.pop("async_fallback", False) 

768 prepared_statement_cache_size = kw.pop( 

769 "prepared_statement_cache_size", 100 

770 ) 

771 if util.asbool(async_fallback): 

772 return AsyncAdaptFallback_asyncpg_connection( 

773 self, 

774 await_fallback(self.asyncpg.connect(*arg, **kw)), 

775 prepared_statement_cache_size=prepared_statement_cache_size, 

776 ) 

777 else: 

778 return AsyncAdapt_asyncpg_connection( 

779 self, 

780 await_only(self.asyncpg.connect(*arg, **kw)), 

781 prepared_statement_cache_size=prepared_statement_cache_size, 

782 ) 

783 

784 class Error(Exception): 

785 pass 

786 

787 class Warning(Exception): # noqa 

788 pass 

789 

790 class InterfaceError(Error): 

791 pass 

792 

793 class DatabaseError(Error): 

794 pass 

795 

796 class InternalError(DatabaseError): 

797 pass 

798 

799 class OperationalError(DatabaseError): 

800 pass 

801 

802 class ProgrammingError(DatabaseError): 

803 pass 

804 

805 class IntegrityError(DatabaseError): 

806 pass 

807 

808 class DataError(DatabaseError): 

809 pass 

810 

811 class NotSupportedError(DatabaseError): 

812 pass 

813 

814 class InternalServerError(InternalError): 

815 pass 

816 

817 class InvalidCachedStatementError(NotSupportedError): 

818 def __init__(self, message): 

819 super( 

820 AsyncAdapt_asyncpg_dbapi.InvalidCachedStatementError, self 

821 ).__init__( 

822 message + " (SQLAlchemy asyncpg dialect will now invalidate " 

823 "all prepared caches in response to this exception)", 

824 ) 

825 

826 @util.memoized_property 

827 def _asyncpg_error_translate(self): 

828 import asyncpg 

829 

830 return { 

831 asyncpg.exceptions.IntegrityConstraintViolationError: self.IntegrityError, # noqa: E501 

832 asyncpg.exceptions.PostgresError: self.Error, 

833 asyncpg.exceptions.SyntaxOrAccessError: self.ProgrammingError, 

834 asyncpg.exceptions.InterfaceError: self.InterfaceError, 

835 asyncpg.exceptions.InvalidCachedStatementError: self.InvalidCachedStatementError, # noqa: E501 

836 asyncpg.exceptions.InternalServerError: self.InternalServerError, 

837 } 

838 

839 def Binary(self, value): 

840 return value 

841 

842 STRING = util.symbol("STRING") 

843 TIMESTAMP = util.symbol("TIMESTAMP") 

844 TIMESTAMP_W_TZ = util.symbol("TIMESTAMP_W_TZ") 

845 TIME = util.symbol("TIME") 

846 TIME_W_TZ = util.symbol("TIME_W_TZ") 

847 DATE = util.symbol("DATE") 

848 INTERVAL = util.symbol("INTERVAL") 

849 NUMBER = util.symbol("NUMBER") 

850 FLOAT = util.symbol("FLOAT") 

851 BOOLEAN = util.symbol("BOOLEAN") 

852 INTEGER = util.symbol("INTEGER") 

853 BIGINTEGER = util.symbol("BIGINTEGER") 

854 BYTES = util.symbol("BYTES") 

855 DECIMAL = util.symbol("DECIMAL") 

856 JSON = util.symbol("JSON") 

857 JSONB = util.symbol("JSONB") 

858 ENUM = util.symbol("ENUM") 

859 UUID = util.symbol("UUID") 

860 BYTEA = util.symbol("BYTEA") 

861 

862 DATETIME = TIMESTAMP 

863 BINARY = BYTEA 

864 

865 
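# PostgreSQL type names rendered as "$n::<typename>" casts by
# AsyncAdapt_asyncpg_cursor._parameter_placeholders()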

866_pg_types = { 

867 AsyncAdapt_asyncpg_dbapi.STRING: "varchar", 

868 AsyncAdapt_asyncpg_dbapi.TIMESTAMP: "timestamp", 

869 AsyncAdapt_asyncpg_dbapi.TIMESTAMP_W_TZ: "timestamp with time zone", 

870 AsyncAdapt_asyncpg_dbapi.DATE: "date", 

871 AsyncAdapt_asyncpg_dbapi.TIME: "time", 

872 AsyncAdapt_asyncpg_dbapi.TIME_W_TZ: "time with time zone", 

873 AsyncAdapt_asyncpg_dbapi.INTERVAL: "interval", 

874 AsyncAdapt_asyncpg_dbapi.NUMBER: "numeric", 

875 AsyncAdapt_asyncpg_dbapi.FLOAT: "float", 

876 AsyncAdapt_asyncpg_dbapi.BOOLEAN: "bool", 

877 AsyncAdapt_asyncpg_dbapi.INTEGER: "integer", 

878 AsyncAdapt_asyncpg_dbapi.BIGINTEGER: "bigint", 

879 AsyncAdapt_asyncpg_dbapi.BYTES: "bytes", 

880 AsyncAdapt_asyncpg_dbapi.DECIMAL: "decimal", 

881 AsyncAdapt_asyncpg_dbapi.JSON: "json", 

882 AsyncAdapt_asyncpg_dbapi.JSONB: "jsonb", 

883 AsyncAdapt_asyncpg_dbapi.ENUM: "enum", 

884 AsyncAdapt_asyncpg_dbapi.UUID: "uuid", 

885 AsyncAdapt_asyncpg_dbapi.BYTEA: "bytea", 

886} 

887 

888 

889class PGDialect_asyncpg(PGDialect): 

890 driver = "asyncpg" 

891 supports_statement_cache = True 

892 

893 supports_unicode_statements = True 

894 supports_server_side_cursors = True 

895 

896 supports_unicode_binds = True 

897 has_terminate = True 

898 

899 default_paramstyle = "format" 

900 supports_sane_multi_rowcount = False 

901 execution_ctx_cls = PGExecutionContext_asyncpg 

902 statement_compiler = PGCompiler_asyncpg 

903 preparer = PGIdentifierPreparer_asyncpg 

904 

905 use_setinputsizes = True 

906 

907 use_native_uuid = True 

908 

909 colspecs = util.update_copy( 

910 PGDialect.colspecs, 

911 { 

912 sqltypes.Time: AsyncpgTime, 

913 sqltypes.Date: AsyncpgDate, 

914 sqltypes.DateTime: AsyncpgDateTime, 

915 sqltypes.Interval: AsyncPgInterval, 

916 INTERVAL: AsyncPgInterval, 

917 UUID: AsyncpgUUID, 

918 sqltypes.Boolean: AsyncpgBoolean, 

919 sqltypes.Integer: AsyncpgInteger, 

920 sqltypes.BigInteger: AsyncpgBigInteger, 

921 sqltypes.Numeric: AsyncpgNumeric, 

922 sqltypes.Float: AsyncpgFloat, 

923 sqltypes.JSON: AsyncpgJSON, 

924 json.JSONB: AsyncpgJSONB, 

925 sqltypes.JSON.JSONPathType: AsyncpgJSONPathType, 

926 sqltypes.JSON.JSONIndexType: AsyncpgJSONIndexType, 

927 sqltypes.JSON.JSONIntIndexType: AsyncpgJSONIntIndexType, 

928 sqltypes.JSON.JSONStrIndexType: AsyncpgJSONStrIndexType, 

929 sqltypes.Enum: AsyncPgEnum, 

930 OID: AsyncpgOID, 

931 REGCLASS: AsyncpgREGCLASS, 

932 }, 

933 ) 

934 is_async = True 

935 _invalidate_schema_cache_asof = 0 

936 

937 def _invalidate_schema_cache(self): 

938 self._invalidate_schema_cache_asof = time.time() 

939 

940 @util.memoized_property 

941 def _dbapi_version(self): 

942 if self.dbapi and hasattr(self.dbapi, "__version__"): 

943 return tuple( 

944 [ 

945 int(x) 

946 for x in re.findall( 

947 r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__ 

948 ) 

949 ] 

950 ) 

951 else: 

952 return (99, 99, 99) 

953 

954 @classmethod 

955 def dbapi(cls): 

956 return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg")) 

957 

958 @util.memoized_property 

959 def _isolation_lookup(self): 

960 return { 

961 "AUTOCOMMIT": "autocommit", 

962 "READ COMMITTED": "read_committed", 

963 "REPEATABLE READ": "repeatable_read", 

964 "SERIALIZABLE": "serializable", 

965 } 

966 

967 def set_isolation_level(self, connection, level): 

968 try: 

969 level = self._isolation_lookup[level.replace("_", " ")] 

970 except KeyError as err: 

971 util.raise_( 

972 exc.ArgumentError( 

973 "Invalid value '%s' for isolation_level. " 

974 "Valid isolation levels for %s are %s" 

975 % (level, self.name, ", ".join(self._isolation_lookup)) 

976 ), 

977 replace_context=err, 

978 ) 

979 

980 connection.set_isolation_level(level) 

981 

982 def set_readonly(self, connection, value): 

983 connection.readonly = value 

984 

985 def get_readonly(self, connection): 

986 return connection.readonly 

987 

988 def set_deferrable(self, connection, value): 

989 connection.deferrable = value 

990 

991 def get_deferrable(self, connection): 

992 return connection.deferrable 

993 

994 def do_terminate(self, dbapi_connection) -> None: 

995 dbapi_connection.terminate() 

996 

997 def create_connect_args(self, url): 
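# connection parameters are passed entirely as keyword arguments to
# asyncpg.connect(), hence the empty positional argument list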

998 opts = url.translate_connect_args(username="user") 

999 

1000 opts.update(url.query) 

1001 util.coerce_kw_type(opts, "prepared_statement_cache_size", int) 

1002 util.coerce_kw_type(opts, "port", int) 

1003 return ([], opts) 

1004 

1005 @classmethod 

1006 def get_pool_class(cls, url): 

1007 

1008 async_fallback = url.query.get("async_fallback", False) 

1009 

1010 if util.asbool(async_fallback): 

1011 return pool.FallbackAsyncAdaptedQueuePool 

1012 else: 

1013 return pool.AsyncAdaptedQueuePool 

1014 

1015 def is_disconnect(self, e, connection, cursor): 

1016 if connection: 

1017 return connection._connection.is_closed() 

1018 else: 

1019 return isinstance( 

1020 e, self.dbapi.InterfaceError 

1021 ) and "connection is closed" in str(e) 

1022 

1023 def do_set_input_sizes(self, cursor, list_of_tuples, context): 
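# forward the per-parameter DBAPI type symbols collected by the compiler to
# the adapted cursor's setinputsizes(); this dialect uses the positional
# "format" paramstyle, so the first branch is the one normally taken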

1024 if self.positional: 

1025 cursor.setinputsizes( 

1026 *[dbtype for key, dbtype, sqltype in list_of_tuples] 

1027 ) 

1028 else: 

1029 cursor.setinputsizes( 

1030 **{ 

1031 key: dbtype 

1032 for key, dbtype, sqltype in list_of_tuples 

1033 if dbtype 

1034 } 

1035 ) 

1036 

1037 async def setup_asyncpg_json_codec(self, conn): 

1038 """set up JSON codec for asyncpg. 

1039 

1040 This occurs for all new connections and 

1041 can be overridden by third party dialects. 

1042 

1043 .. versionadded:: 1.4.27 

1044 

1045 """ 

1046 

1047 asyncpg_connection = conn._connection 

1048 deserializer = self._json_deserializer or _py_json.loads 

1049 

1050 def _json_decoder(bin_value): 

1051 return deserializer(bin_value.decode()) 

1052 

1053 await asyncpg_connection.set_type_codec( 

1054 "json", 

1055 encoder=str.encode, 

1056 decoder=_json_decoder, 

1057 schema="pg_catalog", 

1058 format="binary", 

1059 ) 

1060 

1061 async def setup_asyncpg_jsonb_codec(self, conn): 

1062 """set up JSONB codec for asyncpg. 

1063 

1064 This occurs for all new connections and 

1065 can be overridden by third party dialects. 

1066 

1067 .. versionadded:: 1.4.27 

1068 

1069 """ 

1070 

1071 asyncpg_connection = conn._connection 

1072 deserializer = self._json_deserializer or _py_json.loads 

1073 

1074 def _jsonb_encoder(str_value): 

1075 # \x01 is the prefix for jsonb used by PostgreSQL. 

1076 # asyncpg requires it when format='binary' 

1077 return b"\x01" + str_value.encode() 

1078 

1079 deserializer = self._json_deserializer or _py_json.loads 

1080 

1081 def _jsonb_decoder(bin_value): 

1082 # the byte is the \x01 prefix for jsonb used by PostgreSQL. 

1083 # asyncpg returns it when format='binary' 

1084 return deserializer(bin_value[1:].decode()) 

1085 

1086 await asyncpg_connection.set_type_codec( 

1087 "jsonb", 

1088 encoder=_jsonb_encoder, 

1089 decoder=_jsonb_decoder, 

1090 schema="pg_catalog", 

1091 format="binary", 

1092 ) 

1093 

1094 def on_connect(self): 

1095 """on_connect for asyncpg 

1096 

1097 A major component of this for asyncpg is to set up type decoders at the 

1098 asyncpg level. 

1099 

1100 See https://github.com/MagicStack/asyncpg/issues/623 for 

1101 notes on JSON/JSONB implementation. 

1102 

1103 """ 

1104 

1105 super_connect = super(PGDialect_asyncpg, self).on_connect() 

1106 

1107 def connect(conn): 

1108 conn.await_(self.setup_asyncpg_json_codec(conn)) 

1109 conn.await_(self.setup_asyncpg_jsonb_codec(conn)) 

1110 if super_connect is not None: 

1111 super_connect(conn) 

1112 

1113 return connect 

1114 

1115 def get_driver_connection(self, connection): 

1116 return connection._connection 

1117 

1118 

1119dialect = PGDialect_asyncpg