1# dialects/postgresql/asyncpg.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors <see AUTHORS
3# file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9r"""
10.. dialect:: postgresql+asyncpg
11 :name: asyncpg
12 :dbapi: asyncpg
13 :connectstring: postgresql+asyncpg://user:password@host:port/dbname[?key=value&key=value...]
14 :url: https://magicstack.github.io/asyncpg/
15
16The asyncpg dialect is SQLAlchemy's first Python asyncio dialect.
17
18Using a special asyncio mediation layer, the asyncpg dialect is usable
19as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
20extension package.
21
22This dialect should normally be used only with the
23:func:`_asyncio.create_async_engine` engine creation function::
24
25 from sqlalchemy.ext.asyncio import create_async_engine
26
27 engine = create_async_engine(
28 "postgresql+asyncpg://user:pass@hostname/dbname"
29 )
30
31.. versionadded:: 1.4
32
33.. note::
34
    By default asyncpg does not decode the ``json`` and ``jsonb`` types and
    returns them as strings. SQLAlchemy sets a default type decoder for the
    ``json`` and ``jsonb`` types using the Python builtin ``json.loads``
    function. The JSON implementation used can be changed by passing the
    ``json_deserializer`` parameter when creating the engine with
    :func:`create_engine` or :func:`create_async_engine`.
41
42.. _asyncpg_multihost:
43
44Multihost Connections
45--------------------------
46
The asyncpg dialect supports multiple fallback hosts in the same way as the
psycopg2 and psycopg dialects, using the same syntax of
``host=<host>:<port>`` combinations as additional query string arguments.
However, there is no default port, so every host must include a port
number; otherwise an exception is raised::
53
54 engine = create_async_engine(
55 "postgresql+asyncpg://user:password@/dbname?host=HostA:5432&host=HostB:5432&host=HostC:5432"
56 )
57
58For complete background on this syntax, see :ref:`psycopg2_multi_host`.
59
60.. versionadded:: 2.0.18
61
62.. seealso::
63
64 :ref:`psycopg2_multi_host`
65
66.. _asyncpg_prepared_statement_cache:
67
68Prepared Statement Cache
69--------------------------
70
71The asyncpg SQLAlchemy dialect makes use of ``asyncpg.connection.prepare()``
for all statements. The prepared statement objects are cached after
construction, which appears to grant a performance improvement of 10% or more
for statement invocation. The cache is on a per-DBAPI connection basis, which
75means that the primary storage for prepared statements is within DBAPI
76connections pooled within the connection pool. The size of this cache
77defaults to 100 statements per DBAPI connection and may be adjusted using the
78``prepared_statement_cache_size`` DBAPI argument (note that while this argument
79is implemented by SQLAlchemy, it is part of the DBAPI emulation portion of the
80asyncpg dialect, therefore is handled as a DBAPI argument, not a dialect
81argument)::
82
83
84 engine = create_async_engine(
85 "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500"
86 )
87
88To disable the prepared statement cache, use a value of zero::
89
90 engine = create_async_engine(
91 "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0"
92 )
93
94.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg.
95
96
97.. warning:: The ``asyncpg`` database driver necessarily uses caches for
98 PostgreSQL type OIDs, which become stale when custom PostgreSQL datatypes
99 such as ``ENUM`` objects are changed via DDL operations. Additionally,
100 prepared statements themselves which are optionally cached by SQLAlchemy's
101 driver as described above may also become "stale" when DDL has been emitted
102 to the PostgreSQL database which modifies the tables or other objects
103 involved in a particular prepared statement.
104
105 The SQLAlchemy asyncpg dialect will invalidate these caches within its local
106 process when statements that represent DDL are emitted on a local
107 connection, but this is only controllable within a single Python process /
108 database engine. If DDL changes are made from other database engines
109 and/or processes, a running application may encounter asyncpg exceptions
110 ``InvalidCachedStatementError`` and/or ``InternalServerError("cache lookup
111 failed for type <oid>")`` if it refers to pooled database connections which
112 operated upon the previous structures. The SQLAlchemy asyncpg dialect will
113 recover from these error cases when the driver raises these exceptions by
114 clearing its internal caches as well as those of the asyncpg driver in
115 response to them, but cannot prevent them from being raised in the first
116 place if the cached prepared statement or asyncpg type caches have gone
117 stale, nor can it retry the statement as the PostgreSQL transaction is
118 invalidated when these errors occur.
119
120.. _asyncpg_prepared_statement_name:
121
Prepared Statement Name with PgBouncer
123--------------------------------------
124
125By default, asyncpg enumerates prepared statements in numeric order, which
126can lead to errors if a name has already been taken for another prepared
127statement. This issue can arise if your application uses database proxies
128such as PgBouncer to handle connections. One possible workaround is to
129use dynamic prepared statement names, which asyncpg now supports through
130an optional ``name`` value for the statement name. This allows you to
131generate your own unique names that won't conflict with existing ones.
132To achieve this, you can provide a function that will be called every time
133a prepared statement is prepared::
134
135 from uuid import uuid4
136
137 engine = create_async_engine(
138 "postgresql+asyncpg://user:pass@somepgbouncer/dbname",
139 poolclass=NullPool,
140 connect_args={
141 "prepared_statement_name_func": lambda: f"__asyncpg_{uuid4()}__",
142 },
143 )
144
145.. seealso::
146
147 https://github.com/MagicStack/asyncpg/issues/837
148
149 https://github.com/sqlalchemy/sqlalchemy/issues/6467
150
.. warning:: When using PgBouncer, to prevent a buildup of useless prepared statements in
152 your application, it's important to use the :class:`.NullPool` pool
153 class, and to configure PgBouncer to use `DISCARD <https://www.postgresql.org/docs/current/sql-discard.html>`_
154 when returning connections. The DISCARD command is used to release resources held by the db connection,
155 including prepared statements. Without proper setup, prepared statements can
156 accumulate quickly and cause performance issues.
157
158Disabling the PostgreSQL JIT to improve ENUM datatype handling
159---------------------------------------------------------------
160
161Asyncpg has an `issue <https://github.com/MagicStack/asyncpg/issues/727>`_ when
162using PostgreSQL ENUM datatypes, where upon the creation of new database
163connections, an expensive query may be emitted in order to retrieve metadata
164regarding custom types which has been shown to negatively affect performance.
165To mitigate this issue, the PostgreSQL "jit" setting may be disabled from the
166client using this setting passed to :func:`_asyncio.create_async_engine`::
167
168 engine = create_async_engine(
169 "postgresql+asyncpg://user:password@localhost/tmp",
170 connect_args={"server_settings": {"jit": "off"}},
171 )
172
173.. seealso::
174
175 https://github.com/MagicStack/asyncpg/issues/727
176
177""" # noqa
178
179from __future__ import annotations
180
181from collections import deque
182import decimal
183import json as _py_json
184import re
185import time
186
187from . import json
188from . import ranges
189from .array import ARRAY as PGARRAY
190from .base import _DECIMAL_TYPES
191from .base import _FLOAT_TYPES
192from .base import _INT_TYPES
193from .base import ENUM
194from .base import INTERVAL
195from .base import OID
196from .base import PGCompiler
197from .base import PGDialect
198from .base import PGExecutionContext
199from .base import PGIdentifierPreparer
200from .base import REGCLASS
201from .base import REGCONFIG
202from .types import BIT
203from .types import BYTEA
204from .types import CITEXT
205from ... import exc
206from ... import pool
207from ... import util
208from ...connectors.asyncio import AsyncAdapt_terminate
209from ...engine import AdaptedConnection
210from ...engine import processors
211from ...sql import sqltypes
212from ...util.concurrency import asyncio
213from ...util.concurrency import await_fallback
214from ...util.concurrency import await_only
215
216
class AsyncpgARRAY(PGARRAY):
    """PostgreSQL ARRAY for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
219
220
class AsyncpgString(sqltypes.String):
    """String type for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
223
224
class AsyncpgREGCONFIG(REGCONFIG):
    """REGCONFIG type for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
227
228
class AsyncpgTime(sqltypes.Time):
    """Time type for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
231
232
class AsyncpgBit(BIT):
    """BIT type for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
235
236
class AsyncpgByteA(BYTEA):
    """BYTEA type for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
239
240
class AsyncpgDate(sqltypes.Date):
    """Date type for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
243
244
class AsyncpgDateTime(sqltypes.DateTime):
    """DateTime type for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
247
248
class AsyncpgBoolean(sqltypes.Boolean):
    """Boolean type for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
251
252
class AsyncPgInterval(INTERVAL):
    """INTERVAL type for asyncpg with ``render_bind_cast`` enabled."""

    render_bind_cast = True

    @classmethod
    def adapt_emulated_to_native(cls, interval, **kw):
        """Build the native INTERVAL from a generic ``Interval``.

        Only the generic type's ``second_precision`` is carried over (as
        ``precision``); any extra keyword arguments are ignored.
        """
        second_precision = interval.second_precision
        return AsyncPgInterval(precision=second_precision)
259
260
class AsyncPgEnum(ENUM):
    """PostgreSQL ENUM for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
263
264
class AsyncpgInteger(sqltypes.Integer):
    """Integer type for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
267
268
class AsyncpgSmallInteger(sqltypes.SmallInteger):
    """SmallInteger type for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
271
272
class AsyncpgBigInteger(sqltypes.BigInteger):
    """BigInteger type for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
275
276
class AsyncpgJSON(json.JSON):
    """JSON type for asyncpg that applies no Python-side result processing."""

    def result_processor(self, dialect, coltype):
        # Fetched values are returned exactly as the driver produced them;
        # decoding is expected to happen driver-side (see the module
        # docstring's note on the default ``json.loads`` type decoder).
        return None
280
281
class AsyncpgJSONB(json.JSONB):
    """JSONB type for asyncpg that applies no Python-side result processing."""

    def result_processor(self, dialect, coltype):
        # Fetched values are returned exactly as the driver produced them;
        # decoding is expected to happen driver-side (see the module
        # docstring's note on the default ``json.loads`` type decoder).
        return None
285
286
class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType):
    """asyncpg-specific subclass of the generic JSON index type; adds no
    behavior of its own."""
289
290
class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
    """Integer JSON index type for asyncpg, compiled via the
    ``json_int_index`` visitor and with ``render_bind_cast`` enabled."""

    __visit_name__ = "json_int_index"
    render_bind_cast = True
295
296
class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
    """String JSON index type for asyncpg, compiled via the
    ``json_str_index`` visitor and with ``render_bind_cast`` enabled."""

    __visit_name__ = "json_str_index"
    render_bind_cast = True
301
302
class AsyncpgJSONPathType(json.JSONPathType):
    """JSON path type whose bound values are sent as lists of strings."""

    def bind_processor(self, dialect):
        """Return a converter for bound JSON path values.

        * a ``str`` is passed through untouched (assumed to already be a
          JSON path literal, which allows casting json path literals);
        * any other falsy value (e.g. ``None``, empty sequence) becomes
          an empty list;
        * otherwise each element is converted with ``str()``.
        """

        def process(value):
            if isinstance(value, str):
                return value
            if not value:
                return []
            return [str(elem) for elem in value]

        return process
317
318
class AsyncpgNumeric(sqltypes.Numeric):
    """Numeric type for asyncpg.

    Bound values are handed to the driver unchanged (with an explicit bind
    cast), and result conversion is chosen from the column's type OID.
    """

    render_bind_cast = True

    def bind_processor(self, dialect):
        # no Python-side conversion of bound values
        return None

    def result_processor(self, dialect, coltype):
        """Return a converter for result values of the given column type.

        :param coltype: the type code of the result column; the cursor in
          this module fills it from the asyncpg attribute's type OID.
        :returns: a conversion callable, or ``None`` when values are
          passed through as the driver returned them.
        :raises InvalidRequestError: if ``coltype`` is not one of the
          known float, decimal or integer type OIDs.
        """
        if self.asdecimal:
            if coltype in _FLOAT_TYPES:
                return processors.to_decimal_processor_factory(
                    decimal.Decimal, self._effective_decimal_return_scale
                )
            elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
                # Passed through unchanged.  NOTE(review): this relies on
                # asyncpg already returning Decimal for NUMERIC columns.
                return None
        else:
            if coltype in _FLOAT_TYPES:
                # Passed through unchanged; float columns are assumed to
                # already arrive as Python floats from asyncpg.
                return None
            elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
                return processors.to_float

        # %r (not %d) so that an unexpected non-integer type code such as
        # None still produces this error instead of a TypeError raised by
        # the string formatting itself; integer codes render identically.
        raise exc.InvalidRequestError(
            "Unknown PG numeric type: %r" % (coltype,)
        )
348
349
class AsyncpgFloat(AsyncpgNumeric, sqltypes.Float):
    """Float type for asyncpg, reusing :class:`.AsyncpgNumeric` processing
    while compiling via the ``float`` visitor."""

    __visit_name__ = "float"
    render_bind_cast = True
353
354
class AsyncpgREGCLASS(REGCLASS):
    """REGCLASS type for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
357
358
class AsyncpgOID(OID):
    """OID type for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
361
362
class AsyncpgCHAR(sqltypes.CHAR):
    """CHAR type for asyncpg; sets ``render_bind_cast`` so bound
    parameters of this type are rendered with an explicit cast."""

    render_bind_cast = True
365
366
class _AsyncpgRange(ranges.AbstractSingleRangeImpl):
    """Convert between SQLAlchemy ``Range`` and ``asyncpg.Range`` values."""

    def bind_processor(self, dialect):
        """Return a converter turning ``ranges.Range`` into asyncpg ranges.

        Values that are not ``ranges.Range`` instances (including None)
        are passed through unchanged.
        """
        asyncpg_range_cls = dialect.dbapi.asyncpg.Range

        def process(value):
            if not isinstance(value, ranges.Range):
                return value
            return asyncpg_range_cls(
                value.lower,
                value.upper,
                lower_inc=value.bounds[0] == "[",
                upper_inc=value.bounds[1] == "]",
                empty=value.empty,
            )

        return process

    def result_processor(self, dialect, coltype):
        """Return a converter turning asyncpg ranges into ``ranges.Range``.

        Empty ranges are given ``"[)"`` bounds; None passes through.
        """

        def process(value):
            if value is None:
                return None
            is_empty = value.isempty
            lower_bracket = "[" if is_empty or value.lower_inc else "("
            upper_bracket = "]" if not is_empty and value.upper_inc else ")"
            return ranges.Range(
                value.lower,
                value.upper,
                bounds=lower_bracket + upper_bracket,
                empty=is_empty,
            )

        return process
398
399
class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl):
    """Convert between SQLAlchemy multiranges and sequences of
    ``asyncpg.Range`` values."""

    def bind_processor(self, dialect):
        """Return a converter for bound multirange values.

        Strings and None are passed through unchanged; any other value is
        iterated and each ``ranges.Range`` element is converted to an
        ``asyncpg.Range`` (other elements are kept as-is), producing a list.
        """
        asyncpg_range_cls = dialect.dbapi.asyncpg.Range

        def convert_element(element):
            if not isinstance(element, ranges.Range):
                return element
            return asyncpg_range_cls(
                element.lower,
                element.upper,
                lower_inc=element.bounds[0] == "[",
                upper_inc=element.bounds[1] == "]",
                empty=element.empty,
            )

        def process(value):
            if value is None or isinstance(value, str):
                return value
            return [convert_element(element) for element in value]

        return process

    def result_processor(self, dialect, coltype):
        """Return a converter building ``ranges.MultiRange`` from fetched
        sequences of asyncpg ranges; None passes through."""

        def convert_element(rvalue):
            if rvalue is None:
                return None
            is_empty = rvalue.isempty
            lower_bracket = "[" if is_empty or rvalue.lower_inc else "("
            upper_bracket = (
                "]" if not is_empty and rvalue.upper_inc else ")"
            )
            return ranges.Range(
                rvalue.lower,
                rvalue.upper,
                bounds=lower_bracket + upper_bracket,
                empty=is_empty,
            )

        def process(value):
            if value is None:
                return None
            return ranges.MultiRange(convert_element(elem) for elem in value)

        return process
445
446
class PGExecutionContext_asyncpg(PGExecutionContext):
    """Execution context that keeps asyncpg's cached schema state in sync
    with DDL and with cache-related driver errors."""

    def handle_dbapi_exception(self, e):
        # These errors indicate stale cached statements/types; bump the
        # dialect-wide invalidation timestamp so connections reload.
        dbapi = self.dialect.dbapi
        stale_cache_errors = (
            dbapi.InvalidCachedStatementError,
            dbapi.InternalServerError,
        )
        if isinstance(e, stale_cache_errors):
            self.dialect._invalidate_schema_cache()

    def pre_exec(self):
        # DDL may change types/tables referenced by cached statements.
        if self.isddl:
            self.dialect._invalidate_schema_cache()

        # Hand the current invalidation timestamp to the cursor, which
        # passes it to the connection when preparing the statement.
        self.cursor._invalidate_schema_cache_asof = (
            self.dialect._invalidate_schema_cache_asof
        )

    def create_server_side_cursor(self):
        return self._dbapi_connection.cursor(server_side=True)
471
472
class PGCompiler_asyncpg(PGCompiler):
    """SQL compiler for the asyncpg dialect; currently identical to
    :class:`.PGCompiler`."""
475
476
class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer):
    """Identifier preparer for the asyncpg dialect; currently identical to
    :class:`.PGIdentifierPreparer`."""
479
480
class AsyncAdapt_asyncpg_cursor:
    """PEP 249-style client-side cursor over an asyncpg connection.

    Statements are prepared through the owning adapted connection and all
    result rows are fetched eagerly into a local deque, which the
    ``fetch*()`` methods and iteration then drain.  Coroutines are driven
    through ``adapt_connection.await_``.
    """

    __slots__ = (
        "_adapt_connection",
        "_connection",
        "_rows",
        "description",
        "arraysize",
        "rowcount",
        "_cursor",
        "_invalidate_schema_cache_asof",
    )

    server_side = False
    _awaitable_cursor_close: bool = False

    def __init__(self, adapt_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection
        self._rows = deque()
        self._cursor = None
        self.description = None
        self.arraysize = 1
        self.rowcount = -1
        # set by the execution context's pre_exec() before each execution
        self._invalidate_schema_cache_asof = 0

    async def _async_soft_close(self) -> None:
        return

    def close(self):
        """Discard any buffered rows."""
        self._rows.clear()

    def _handle_exception(self, error):
        # delegates to the connection, which translates and re-raises
        self._adapt_connection._handle_exception(error)

    async def _prepare_and_execute(self, operation, parameters):
        """Prepare ``operation`` (possibly from cache) and execute it.

        Populates ``description`` from the statement attributes.  For a
        server-side cursor an asyncpg cursor is opened; otherwise all rows
        are fetched and ``rowcount`` is parsed from the status message
        (``-1`` when it cannot be determined).  Errors are translated via
        the connection.
        """
        adapt_connection = self._adapt_connection

        async with adapt_connection._execute_mutex:
            if not adapt_connection._started:
                await adapt_connection._start_transaction()

            if parameters is None:
                parameters = ()

            try:
                prepared_stmt, attributes = await adapt_connection._prepare(
                    operation, self._invalidate_schema_cache_asof
                )

                if attributes:
                    # PEP 249 7-item description; only name and type code
                    # (the asyncpg type OID) are known.
                    self.description = [
                        (
                            attr.name,
                            attr.type.oid,
                            None,
                            None,
                            None,
                            None,
                            None,
                        )
                        for attr in attributes
                    ]
                else:
                    self.description = None

                if self.server_side:
                    self._cursor = await prepared_stmt.cursor(*parameters)
                    self.rowcount = -1
                else:
                    self._rows = deque(await prepared_stmt.fetch(*parameters))
                    status = prepared_stmt.get_statusmsg()

                    # e.g. "UPDATE 3" or "INSERT 0 1": the trailing number
                    # is the affected row count
                    reg = re.match(
                        r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)",
                        status or "",
                    )
                    if reg:
                        self.rowcount = int(reg.group(1))
                    else:
                        self.rowcount = -1

            except Exception as error:
                self._handle_exception(error)

    async def _executemany(self, operation, seq_of_parameters):
        """Run ``operation`` once per parameter set via asyncpg's
        ``executemany``; returns whatever the driver returns."""
        adapt_connection = self._adapt_connection

        # executemany() produces no result rows and this adapter does not
        # compute a row count for it, so clear what a previous execute()
        # left behind instead of exposing stale rows / rowcount.
        self.description = None
        self.rowcount = -1
        self._rows.clear()

        async with adapt_connection._execute_mutex:
            await adapt_connection._check_type_cache_invalidation(
                self._invalidate_schema_cache_asof
            )

            if not adapt_connection._started:
                await adapt_connection._start_transaction()

            try:
                return await self._connection.executemany(
                    operation, seq_of_parameters
                )
            except Exception as error:
                self._handle_exception(error)

    def execute(self, operation, parameters=None):
        self._adapt_connection.await_(
            self._prepare_and_execute(operation, parameters)
        )

    def executemany(self, operation, seq_of_parameters):
        return self._adapt_connection.await_(
            self._executemany(operation, seq_of_parameters)
        )

    def setinputsizes(self, *inputsizes):
        raise NotImplementedError()

    def __iter__(self):
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self):
        """Return the next buffered row, or None when exhausted."""
        if self._rows:
            return self._rows.popleft()
        else:
            return None

    def fetchmany(self, size=None):
        """Return up to ``size`` rows (default ``arraysize``)."""
        if size is None:
            size = self.arraysize

        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]

    def fetchall(self):
        """Return all remaining buffered rows as a list."""
        retval = list(self._rows)
        self._rows.clear()
        return retval
618
619
class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
    """Server-side cursor variant.

    The inherited ``_prepare_and_execute`` opens an asyncpg cursor on the
    prepared statement instead of fetching every row; rows are then pulled
    from that cursor on demand into ``_rowbuffer``.
    """

    server_side = True
    __slots__ = ("_rowbuffer",)

    def __init__(self, adapt_connection):
        super().__init__(adapt_connection)
        self._rowbuffer = deque()

    def close(self):
        """Drop the asyncpg cursor reference and discard buffered rows."""
        self._cursor = None
        self._rowbuffer.clear()

    def _buffer_rows(self):
        """Synchronously fetch the next batch of up to 50 rows into the
        buffer."""
        assert self._cursor is not None
        new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
        self._rowbuffer.extend(new_rows)

    def __iter__(self):
        """Yield all remaining rows, refilling the buffer as needed.

        The inherited ``__iter__`` drains ``self._rows``, which a
        server-side cursor never populates, so it would yield nothing.
        """
        while True:
            while self._rowbuffer:
                yield self._rowbuffer.popleft()

            self._buffer_rows()
            if not self._rowbuffer:
                break

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next row, raising ``StopAsyncIteration`` when done.

        Per the async-iterator protocol this must be a coroutine that
        returns one row per call (not an async generator).  The fetch is
        awaited directly here rather than through the synchronous
        ``await_`` bridge that ``_buffer_rows()`` uses, since this method
        already runs as a coroutine.
        """
        if not self._rowbuffer:
            assert self._cursor is not None
            self._rowbuffer.extend(await self._cursor.fetch(50))
        if not self._rowbuffer:
            raise StopAsyncIteration()
        return self._rowbuffer.popleft()

    def fetchone(self):
        """Return the next row, or None when the cursor is exhausted."""
        if not self._rowbuffer:
            self._buffer_rows()
            if not self._rowbuffer:
                return None
        return self._rowbuffer.popleft()

    def fetchmany(self, size=None):
        """Return up to ``size`` rows; with no ``size``, return all rows."""
        if size is None:
            return self.fetchall()

        if not self._rowbuffer:
            self._buffer_rows()

        assert self._cursor is not None
        rb = self._rowbuffer
        lb = len(rb)
        if size > lb:
            rb.extend(
                self._adapt_connection.await_(self._cursor.fetch(size - lb))
            )

        return [rb.popleft() for _ in range(min(size, len(rb)))]

    def fetchall(self):
        """Return buffered rows followed by everything left on the server."""
        ret = list(self._rowbuffer)
        ret.extend(self._adapt_connection.await_(self._all()))
        self._rowbuffer.clear()
        return ret

    async def _all(self):
        rows = []

        # TODO: hand-rolled batching with a hardcoded batch size; this
        # should be improved.
        while True:
            batch = await self._cursor.fetch(1000)
            if not batch:
                break
            rows.extend(batch)
        return rows

    def executemany(self, operation, seq_of_parameters):
        raise NotImplementedError(
            "server side cursor doesn't support executemany yet"
        )
697
698
class AsyncAdapt_asyncpg_connection(AsyncAdapt_terminate, AdaptedConnection):
    """PEP 249-style connection facade over an asyncpg connection.

    A transaction is begun lazily by the cursors (via
    ``_start_transaction()``) unless the isolation level is
    ``"autocommit"``; ``commit()`` and ``rollback()`` end it.  Prepared
    statements may be kept in a per-connection LRU cache, and
    ``_invalidate_schema_cache_asof`` records the newest invalidation
    timestamp for which asyncpg's schema state was reloaded here.
    Coroutines are driven from synchronous code through ``await_``.
    """

    __slots__ = (
        "dbapi",
        "isolation_level",
        "_isolation_setting",
        "readonly",
        "deferrable",
        "_transaction",
        "_started",
        "_prepared_statement_cache",
        "_prepared_statement_name_func",
        "_invalidate_schema_cache_asof",
        "_execute_mutex",
    )

    # bridge for running coroutines; the fallback subclass overrides it
    await_ = staticmethod(await_only)

    def __init__(
        self,
        dbapi,
        connection,
        prepared_statement_cache_size=100,
        prepared_statement_name_func=None,
    ):
        self.dbapi = dbapi
        self._connection = connection
        self.isolation_level = self._isolation_setting = None
        self.readonly = False
        self.deferrable = False
        self._transaction = None
        self._started = False
        self._invalidate_schema_cache_asof = time.time()
        # cursors hold this lock while preparing / executing
        self._execute_mutex = asyncio.Lock()

        # a falsy size (0 / None) disables statement caching
        self._prepared_statement_cache = (
            util.LRUCache(prepared_statement_cache_size)
            if prepared_statement_cache_size
            else None
        )
        # falsy -> default naming, which passes name=None to asyncpg
        self._prepared_statement_name_func = (
            prepared_statement_name_func or self._default_name_func
        )

    async def _check_type_cache_invalidation(self, invalidate_timestamp):
        """Reload asyncpg's schema state if ``invalidate_timestamp`` is
        newer than the last reload recorded on this connection."""
        if invalidate_timestamp <= self._invalidate_schema_cache_asof:
            return
        await self._connection.reload_schema_state()
        self._invalidate_schema_cache_asof = invalidate_timestamp

    async def _prepare(self, operation, invalidate_timestamp):
        """Return ``(prepared_statement, attributes)`` for ``operation``.

        A cached entry is reused only if it was created after
        ``invalidate_timestamp``; otherwise the statement is prepared anew
        and, when caching is enabled, stored with the current time.
        """
        await self._check_type_cache_invalidation(invalidate_timestamp)

        cache = self._prepared_statement_cache

        # The attributes are cached alongside the statement because
        # asyncpg's type cache for them seems to go stale independently of
        # the PreparedStatement.  Statements themselves also go stale for
        # some DDL (e.g. a VARCHAR size change), hence the cross-connection
        # invalidation timestamp comparison.
        if cache is not None and operation in cache:
            prepared_stmt, attributes, cached_at = cache[operation]
            if cached_at > invalidate_timestamp:
                return prepared_stmt, attributes

        prepared_stmt = await self._connection.prepare(
            operation, name=self._prepared_statement_name_func()
        )
        attributes = prepared_stmt.get_attributes()
        if cache is not None:
            cache[operation] = (prepared_stmt, attributes, time.time())
        return prepared_stmt, attributes

    def _handle_exception(self, error):
        """Re-raise ``error``, translated into this adapter's exceptions.

        Transaction state is reset first if the connection is closed.
        Adapter exceptions, and errors whose class hierarchy has no entry
        in the translation map, are re-raised unchanged.
        """
        if self._connection.is_closed():
            self._transaction = None
            self._started = False

        if isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
            raise error

        translate = self.dbapi._asyncpg_error_translate
        for error_cls in type(error).__mro__:
            if error_cls not in translate:
                continue
            translated_error = translate[error_cls](
                "%s: %s" % (type(error), error)
            )
            translated_error.pgcode = translated_error.sqlstate = getattr(
                error, "sqlstate", None
            )
            raise translated_error from error

        raise error

    @property
    def autocommit(self):
        return self.isolation_level == "autocommit"

    @autocommit.setter
    def autocommit(self, value):
        # turning autocommit off restores the last explicit isolation level
        if value:
            self.isolation_level = "autocommit"
        else:
            self.isolation_level = self._isolation_setting

    def ping(self):
        try:
            self.await_(self._async_ping())
        except Exception as error:
            self._handle_exception(error)

    async def _async_ping(self):
        if self._transaction is not None or self.isolation_level == "autocommit":
            await self._connection.fetchrow(";")
            return

        # create a transaction explicitly to support pgbouncer
        # transaction mode. See #10226
        tr = self._connection.transaction()
        await tr.start()
        try:
            await self._connection.fetchrow(";")
        finally:
            await tr.rollback()

    def set_isolation_level(self, level):
        # a new level applies to the next transaction; end the current one
        if self._started:
            self.rollback()
        self.isolation_level = self._isolation_setting = level

    async def _start_transaction(self):
        if self.isolation_level == "autocommit":
            return

        try:
            transaction = self._connection.transaction(
                isolation=self.isolation_level,
                readonly=self.readonly,
                deferrable=self.deferrable,
            )
            self._transaction = transaction
            await transaction.start()
        except Exception as error:
            self._handle_exception(error)
        else:
            self._started = True

    def cursor(self, server_side=False):
        cursor_cls = (
            AsyncAdapt_asyncpg_ss_cursor
            if server_side
            else AsyncAdapt_asyncpg_cursor
        )
        return cursor_cls(self)

    async def _rollback_and_discard(self):
        try:
            await self._transaction.rollback()
        finally:
            # once asyncpg's rollback() has been called, the transaction is
            # finished whether it succeeded or raised; discard it
            self._transaction = None
            self._started = False

    async def _commit_and_discard(self):
        try:
            await self._transaction.commit()
        finally:
            # once asyncpg's commit() has been called, the transaction is
            # finished whether it succeeded or raised; discard it
            self._transaction = None
            self._started = False

    def rollback(self):
        if not self._started:
            return
        try:
            self.await_(self._rollback_and_discard())
        except Exception as error:
            # state is only reset here on success; if rollback() was never
            # actually reached, the asyncpg transaction is left referenced
            self._handle_exception(error)
        else:
            self._transaction = None
            self._started = False

    def commit(self):
        if not self._started:
            return
        try:
            self.await_(self._commit_and_discard())
        except Exception as error:
            # state is only reset here on success; if commit() was never
            # actually reached, the asyncpg transaction is left referenced
            self._handle_exception(error)
        else:
            self._transaction = None
            self._started = False

    def close(self):
        """Roll back any open transaction, then close the connection."""
        self.rollback()

        self.await_(self._connection.close())

    def _terminate_handled_exceptions(self):
        # asyncpg errors are also tolerated during terminate
        base_exceptions = super()._terminate_handled_exceptions()
        return base_exceptions + (self.dbapi.asyncpg.PostgresError,)

    async def _terminate_graceful_close(self) -> None:
        # timeout added in asyncpg 0.14.0 December 2017
        await self._connection.close(timeout=2)
        self._started = False

    def _terminate_force_close(self) -> None:
        self._connection.terminate()
        self._started = False

    @staticmethod
    def _default_name_func():
        # None lets asyncpg choose the prepared statement name
        return None
922
923
class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection):
    """Adapted connection that drives coroutines with ``await_fallback``
    rather than ``await_only``."""

    __slots__ = ()

    await_ = staticmethod(await_fallback)
928
929
class AsyncAdapt_asyncpg_dbapi:
    """PEP 249-style module facade around the ``asyncpg`` module.

    Exposes ``connect()``, ``paramstyle``, the PEP 249 exception hierarchy
    and the mapping used to translate asyncpg exceptions into it.
    """

    def __init__(self, asyncpg):
        self.asyncpg = asyncpg
        self.paramstyle = "numeric_dollar"

    def connect(self, *arg, **kw):
        """Open a driver connection and wrap it in an adapted connection.

        The SQLAlchemy-specific keys ``async_fallback``,
        ``async_creator_fn`` (default ``asyncpg.connect``),
        ``prepared_statement_cache_size`` (default 100) and
        ``prepared_statement_name_func`` are removed from ``kw``; all other
        arguments go to the creator function.
        """
        async_fallback = kw.pop("async_fallback", False)
        creator_fn = kw.pop("async_creator_fn", self.asyncpg.connect)
        prepared_statement_cache_size = kw.pop(
            "prepared_statement_cache_size", 100
        )
        prepared_statement_name_func = kw.pop(
            "prepared_statement_name_func", None
        )

        if util.asbool(async_fallback):
            connection_cls = AsyncAdaptFallback_asyncpg_connection
            await_fn = await_fallback
        else:
            connection_cls = AsyncAdapt_asyncpg_connection
            await_fn = await_only

        return connection_cls(
            self,
            await_fn(creator_fn(*arg, **kw)),
            prepared_statement_cache_size=prepared_statement_cache_size,
            prepared_statement_name_func=prepared_statement_name_func,
        )

    class Error(Exception):
        """Root of the PEP 249 exception hierarchy."""

    class Warning(Exception):  # noqa
        """PEP 249 warning class."""

    class InterfaceError(Error):
        """PEP 249 InterfaceError."""

    class DatabaseError(Error):
        """PEP 249 DatabaseError."""

    class InternalError(DatabaseError):
        """PEP 249 InternalError."""

    class OperationalError(DatabaseError):
        """PEP 249 OperationalError."""

    class ProgrammingError(DatabaseError):
        """PEP 249 ProgrammingError."""

    class IntegrityError(DatabaseError):
        """PEP 249 IntegrityError."""

    class DataError(DatabaseError):
        """PEP 249 DataError."""

    class NotSupportedError(DatabaseError):
        """PEP 249 NotSupportedError."""

    class InternalServerError(InternalError):
        """Translated from asyncpg's InternalServerError."""

    class InternalClientError(InternalError):
        """Translated from asyncpg's InternalClientError."""

    class InvalidCachedStatementError(NotSupportedError):
        """Translated from asyncpg's InvalidCachedStatementError; the
        message notes that prepared caches will be invalidated."""

        def __init__(self, message):
            super().__init__(
                message + " (SQLAlchemy asyncpg dialect will now invalidate "
                "all prepared caches in response to this exception)",
            )

    # pep-249 datatype placeholders. As of SQLAlchemy 2.0 these aren't
    # used, however the test suite looks for these in a few cases.
    STRING = util.symbol("STRING")
    NUMBER = util.symbol("NUMBER")
    DATETIME = util.symbol("DATETIME")

    @util.memoized_property
    def _asyncpg_error_translate(self):
        """Map asyncpg exception classes to this adapter's exceptions.

        Looked up along the raised error's MRO, so the most specific
        matching class wins.
        """
        import asyncpg

        asyncpg_exc = asyncpg.exceptions
        return {
            asyncpg_exc.IntegrityConstraintViolationError: self.IntegrityError,
            asyncpg_exc.PostgresError: self.Error,
            asyncpg_exc.SyntaxOrAccessError: self.ProgrammingError,
            asyncpg_exc.InterfaceError: self.InterfaceError,
            asyncpg_exc.InvalidCachedStatementError: self.InvalidCachedStatementError,  # noqa: E501
            asyncpg_exc.InternalServerError: self.InternalServerError,
            asyncpg_exc.InternalClientError: self.InternalClientError,
        }

    def Binary(self, value):
        # binary values are passed to asyncpg unchanged
        return value
1025
1026
1027class PGDialect_asyncpg(PGDialect):
    # --- dialect capability flags -------------------------------------
    driver = "asyncpg"
    supports_statement_cache = True

    # server-side cursors are provided by AsyncAdapt_asyncpg_ss_cursor
    supports_server_side_cursors = True

    render_bind_cast = True
    # termination is implemented by do_terminate() below
    has_terminate = True

    # must agree with AsyncAdapt_asyncpg_dbapi.paramstyle
    default_paramstyle = "numeric_dollar"
    # the adapted cursor's executemany() does not compute a rowcount
    supports_sane_multi_rowcount = False
    execution_ctx_cls = PGExecutionContext_asyncpg
    statement_compiler = PGCompiler_asyncpg
    preparer = PGIdentifierPreparer_asyncpg

    # generic / PostgreSQL types mapped to their asyncpg-specific variants
    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            sqltypes.String: AsyncpgString,
            sqltypes.ARRAY: AsyncpgARRAY,
            BIT: AsyncpgBit,
            CITEXT: CITEXT,
            REGCONFIG: AsyncpgREGCONFIG,
            sqltypes.Time: AsyncpgTime,
            sqltypes.Date: AsyncpgDate,
            sqltypes.DateTime: AsyncpgDateTime,
            sqltypes.Interval: AsyncPgInterval,
            INTERVAL: AsyncPgInterval,
            sqltypes.Boolean: AsyncpgBoolean,
            sqltypes.Integer: AsyncpgInteger,
            sqltypes.SmallInteger: AsyncpgSmallInteger,
            sqltypes.BigInteger: AsyncpgBigInteger,
            sqltypes.Numeric: AsyncpgNumeric,
            sqltypes.Float: AsyncpgFloat,
            sqltypes.JSON: AsyncpgJSON,
            sqltypes.LargeBinary: AsyncpgByteA,
            json.JSONB: AsyncpgJSONB,
            sqltypes.JSON.JSONPathType: AsyncpgJSONPathType,
            sqltypes.JSON.JSONIndexType: AsyncpgJSONIndexType,
            sqltypes.JSON.JSONIntIndexType: AsyncpgJSONIntIndexType,
            sqltypes.JSON.JSONStrIndexType: AsyncpgJSONStrIndexType,
            sqltypes.Enum: AsyncPgEnum,
            OID: AsyncpgOID,
            REGCLASS: AsyncpgREGCLASS,
            sqltypes.CHAR: AsyncpgCHAR,
            ranges.AbstractSingleRange: _AsyncpgRange,
            ranges.AbstractMultiRange: _AsyncpgMultiRange,
        },
    )
    is_async = True
    # Time of the most recent schema-cache invalidation (updated by
    # _invalidate_schema_cache()); copied onto each cursor by the execution
    # context so connections can tell when to reload asyncpg's caches.
    _invalidate_schema_cache_asof = 0
1078
1079 def _invalidate_schema_cache(self):
1080 self._invalidate_schema_cache_asof = time.time()
1081
1082 @util.memoized_property
1083 def _dbapi_version(self):
1084 if self.dbapi and hasattr(self.dbapi, "__version__"):
1085 return tuple(
1086 [
1087 int(x)
1088 for x in re.findall(
1089 r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
1090 )
1091 ]
1092 )
1093 else:
1094 return (99, 99, 99)
1095
1096 @classmethod
1097 def import_dbapi(cls):
1098 return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg"))
1099
1100 @util.memoized_property
1101 def _isolation_lookup(self):
1102 return {
1103 "AUTOCOMMIT": "autocommit",
1104 "READ COMMITTED": "read_committed",
1105 "REPEATABLE READ": "repeatable_read",
1106 "SERIALIZABLE": "serializable",
1107 }
1108
1109 def get_isolation_level_values(self, dbapi_connection):
1110 return list(self._isolation_lookup)
1111
1112 def set_isolation_level(self, dbapi_connection, level):
1113 dbapi_connection.set_isolation_level(self._isolation_lookup[level])
1114
1115 def detect_autocommit_setting(self, dbapi_conn) -> bool:
1116 return bool(dbapi_conn.autocommit)
1117
1118 def set_readonly(self, connection, value):
1119 connection.readonly = value
1120
1121 def get_readonly(self, connection):
1122 return connection.readonly
1123
1124 def set_deferrable(self, connection, value):
1125 connection.deferrable = value
1126
1127 def get_deferrable(self, connection):
1128 return connection.deferrable
1129
1130 def do_terminate(self, dbapi_connection) -> None:
1131 dbapi_connection.terminate()
1132
1133 def create_connect_args(self, url):
1134 opts = url.translate_connect_args(username="user")
1135 multihosts, multiports = self._split_multihost_from_url(url)
1136
1137 opts.update(url.query)
1138
1139 if multihosts:
1140 assert multiports
1141 if len(multihosts) == 1:
1142 opts["host"] = multihosts[0]
1143 if multiports[0] is not None:
1144 opts["port"] = multiports[0]
1145 elif not all(multihosts):
1146 raise exc.ArgumentError(
1147 "All hosts are required to be present"
1148 " for asyncpg multiple host URL"
1149 )
1150 elif not all(multiports):
1151 raise exc.ArgumentError(
1152 "All ports are required to be present"
1153 " for asyncpg multiple host URL"
1154 )
1155 else:
1156 opts["host"] = list(multihosts)
1157 opts["port"] = list(multiports)
1158 else:
1159 util.coerce_kw_type(opts, "port", int)
1160 util.coerce_kw_type(opts, "prepared_statement_cache_size", int)
1161 return ([], opts)
1162
1163 def do_ping(self, dbapi_connection):
1164 dbapi_connection.ping()
1165 return True
1166
1167 @classmethod
1168 def get_pool_class(cls, url):
1169 async_fallback = url.query.get("async_fallback", False)
1170
1171 if util.asbool(async_fallback):
1172 return pool.FallbackAsyncAdaptedQueuePool
1173 else:
1174 return pool.AsyncAdaptedQueuePool
1175
1176 def is_disconnect(self, e, connection, cursor):
1177 if connection:
1178 return connection._connection.is_closed()
1179 else:
1180 return isinstance(
1181 e, self.dbapi.InterfaceError
1182 ) and "connection is closed" in str(e)
1183
1184 async def setup_asyncpg_json_codec(self, conn):
1185 """set up JSON codec for asyncpg.
1186
1187 This occurs for all new connections and
1188 can be overridden by third party dialects.
1189
1190 .. versionadded:: 1.4.27
1191
1192 """
1193
1194 asyncpg_connection = conn._connection
1195 deserializer = self._json_deserializer or _py_json.loads
1196
1197 def _json_decoder(bin_value):
1198 return deserializer(bin_value.decode())
1199
1200 await asyncpg_connection.set_type_codec(
1201 "json",
1202 encoder=str.encode,
1203 decoder=_json_decoder,
1204 schema="pg_catalog",
1205 format="binary",
1206 )
1207
1208 async def setup_asyncpg_jsonb_codec(self, conn):
1209 """set up JSONB codec for asyncpg.
1210
1211 This occurs for all new connections and
1212 can be overridden by third party dialects.
1213
1214 .. versionadded:: 1.4.27
1215
1216 """
1217
1218 asyncpg_connection = conn._connection
1219
1220 def _jsonb_encoder(str_value):
1221 # \x01 is the prefix for jsonb used by PostgreSQL.
1222 # asyncpg requires it when format='binary'
1223 return b"\x01" + str_value.encode()
1224
1225 deserializer = self._json_deserializer or _py_json.loads
1226
1227 def _jsonb_decoder(bin_value):
1228 # the byte is the \x01 prefix for jsonb used by PostgreSQL.
1229 # asyncpg returns it when format='binary'
1230 return deserializer(bin_value[1:].decode())
1231
1232 await asyncpg_connection.set_type_codec(
1233 "jsonb",
1234 encoder=_jsonb_encoder,
1235 decoder=_jsonb_decoder,
1236 schema="pg_catalog",
1237 format="binary",
1238 )
1239
1240 async def _disable_asyncpg_inet_codecs(self, conn):
1241 asyncpg_connection = conn._connection
1242
1243 await asyncpg_connection.set_type_codec(
1244 "inet",
1245 encoder=lambda s: s,
1246 decoder=lambda s: s,
1247 schema="pg_catalog",
1248 format="text",
1249 )
1250
1251 await asyncpg_connection.set_type_codec(
1252 "cidr",
1253 encoder=lambda s: s,
1254 decoder=lambda s: s,
1255 schema="pg_catalog",
1256 format="text",
1257 )
1258
1259 def on_connect(self):
1260 """on_connect for asyncpg
1261
1262 A major component of this for asyncpg is to set up type decoders at the
1263 asyncpg level.
1264
1265 See https://github.com/MagicStack/asyncpg/issues/623 for
1266 notes on JSON/JSONB implementation.
1267
1268 """
1269
1270 super_connect = super().on_connect()
1271
1272 def connect(conn):
1273 conn.await_(self.setup_asyncpg_json_codec(conn))
1274 conn.await_(self.setup_asyncpg_jsonb_codec(conn))
1275
1276 if self._native_inet_types is False:
1277 conn.await_(self._disable_asyncpg_inet_codecs(conn))
1278 if super_connect is not None:
1279 super_connect(conn)
1280
1281 return connect
1282
1283 def get_driver_connection(self, connection):
1284 return connection._connection
1285
1286
# Module-level entry-point name for this module's dialect class; presumably
# what SQLAlchemy's dialect loader resolves for ``postgresql+asyncpg`` —
# confirm against the dialect registry.
dialect = PGDialect_asyncpg