# dialects/postgresql/asyncpg.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

r"""
.. dialect:: postgresql+asyncpg
    :name: asyncpg
    :dbapi: asyncpg
    :connectstring: postgresql+asyncpg://user:password@host:port/dbname[?key=value&key=value...]
    :url: https://magicstack.github.io/asyncpg/

The asyncpg dialect is SQLAlchemy's first Python asyncio dialect.

Using a special asyncio mediation layer, the asyncpg dialect is usable
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
extension package.

This dialect should normally be used only with the
:func:`_asyncio.create_async_engine` engine creation function::

    from sqlalchemy.ext.asyncio import create_async_engine

    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname"
    )

.. versionadded:: 1.4

.. note::

    By default asyncpg does not decode the ``json`` and ``jsonb`` types and
    returns them as strings. SQLAlchemy sets a default type decoder for the
    ``json`` and ``jsonb`` types using the Python builtin ``json.loads``
    function. The JSON implementation used can be changed by passing the
    ``json_deserializer`` parameter when creating the engine with
    :func:`create_engine` or :func:`create_async_engine`.
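
    For example, to use a custom JSON deserializer, pass it when creating
    the engine; a minimal sketch using the stdlib ``json`` module (any
    callable that accepts a string may be used here)::

        import json

        from sqlalchemy.ext.asyncio import create_async_engine

        def custom_loads(value):
            # hypothetical hook; parse the JSON text returned by the driver
            return json.loads(value)

        engine = create_async_engine(
            "postgresql+asyncpg://user:pass@hostname/dbname",
            json_deserializer=custom_loads,
        )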

.. _asyncpg_multihost:

Multihost Connections
--------------------------

The asyncpg dialect supports multiple fallback hosts in the same way as the
psycopg2 and psycopg dialects. The syntax is the same, using
``host=<host>:<port>`` combinations as additional query string arguments;
however, there is no default port, so all hosts must include a complete
port number, otherwise an exception is raised::

    engine = create_async_engine(
        "postgresql+asyncpg://user:password@/dbname?host=HostA:5432&host=HostB:5432&host=HostC:5432"
    )

For complete background on this syntax, see :ref:`psycopg2_multi_host`.

.. versionadded:: 2.0.18

.. seealso::

    :ref:`psycopg2_multi_host`

.. _asyncpg_prepared_statement_cache:

Prepared Statement Cache
--------------------------

The asyncpg SQLAlchemy dialect makes use of ``asyncpg.connection.prepare()``
for all statements. The prepared statement objects are cached after
construction, which appears to grant a 10% or more performance improvement
for statement invocation. The cache is on a per-DBAPI-connection basis, which
means that the primary storage for prepared statements is within DBAPI
connections pooled within the connection pool. The size of this cache
defaults to 100 statements per DBAPI connection and may be adjusted using the
``prepared_statement_cache_size`` DBAPI argument (note that while this
argument is implemented by SQLAlchemy, it is part of the DBAPI emulation
portion of the asyncpg dialect, and is therefore handled as a DBAPI argument,
not a dialect argument)::

    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500"
    )

To disable the prepared statement cache, use a value of zero::

    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0"
    )

.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg.


.. warning:: The ``asyncpg`` database driver necessarily uses caches for
    PostgreSQL type OIDs, which become stale when custom PostgreSQL datatypes
    such as ``ENUM`` objects are changed via DDL operations. Additionally,
    prepared statements themselves which are optionally cached by SQLAlchemy's
    driver as described above may also become "stale" when DDL has been
    emitted to the PostgreSQL database which modifies the tables or other
    objects involved in a particular prepared statement.

    The SQLAlchemy asyncpg dialect will invalidate these caches within its
    local process when statements that represent DDL are emitted on a local
    connection, but this is only controllable within a single Python process /
    database engine. If DDL changes are made from other database engines
    and/or processes, a running application may encounter asyncpg exceptions
    ``InvalidCachedStatementError`` and/or ``InternalServerError("cache lookup
    failed for type <oid>")`` if it refers to pooled database connections
    which operated upon the previous structures. The SQLAlchemy asyncpg
    dialect will recover from these error cases when the driver raises these
    exceptions by clearing its internal caches as well as those of the asyncpg
    driver in response to them, but cannot prevent them from being raised in
    the first place if the cached prepared statement or asyncpg type caches
    have gone stale, nor can it retry the statement as the PostgreSQL
    transaction is invalidated when these errors occur.
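
If an application must tolerate DDL emitted from other processes, one
possible mitigation (a sketch, not a built-in recovery feature; it assumes
the application can detect out-of-band schema changes) is to dispose of the
engine's connection pool so that subsequent connections load the new type
metadata::

    # all pooled connections are closed; connections opened afterwards
    # see the updated schema
    await engine.dispose()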

.. _asyncpg_prepared_statement_name:

Prepared Statement Name with PGBouncer
--------------------------------------

By default, asyncpg enumerates prepared statements in numeric order, which
can lead to errors if a name has already been taken for another prepared
statement. This issue can arise if your application uses database proxies
such as PgBouncer to handle connections. One possible workaround is to
use dynamic prepared statement names, which asyncpg now supports through
an optional ``name`` value for the statement name. This allows you to
generate your own unique names that won't conflict with existing ones.
To achieve this, you can provide a function that will be called every time
a prepared statement is prepared::

    from uuid import uuid4

    from sqlalchemy.pool import NullPool

    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@somepgbouncer/dbname",
        poolclass=NullPool,
        connect_args={
            "prepared_statement_name_func": lambda: f"__asyncpg_{uuid4()}__",
        },
    )

.. seealso::

    https://github.com/MagicStack/asyncpg/issues/837

    https://github.com/sqlalchemy/sqlalchemy/issues/6467

.. warning:: When using PgBouncer, to prevent a buildup of useless prepared
    statements in your application, it's important to use the
    :class:`.NullPool` pool class, and to configure PgBouncer to use
    `DISCARD <https://www.postgresql.org/docs/current/sql-discard.html>`_
    when returning connections. The DISCARD command is used to release
    resources held by the db connection, including prepared statements.
    Without proper setup, prepared statements can accumulate quickly and
    cause performance issues.
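
For example, when PgBouncer runs in session pooling mode, connection state
can be reset on release via its ``server_reset_query`` setting; this is a
sketch of one possible PgBouncer configuration entry, not a SQLAlchemy
setting::

    server_reset_query = DISCARD ALL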

Disabling the PostgreSQL JIT to improve ENUM datatype handling
---------------------------------------------------------------

Asyncpg has an `issue <https://github.com/MagicStack/asyncpg/issues/727>`_
when using PostgreSQL ENUM datatypes, where upon the creation of new
database connections, an expensive query may be emitted in order to retrieve
metadata regarding custom types, which has been shown to negatively affect
performance. To mitigate this issue, the PostgreSQL "jit" setting may be
disabled from the client using this setting passed to
:func:`_asyncio.create_async_engine`::

    engine = create_async_engine(
        "postgresql+asyncpg://user:password@localhost/tmp",
        connect_args={"server_settings": {"jit": "off"}},
    )

.. seealso::

    https://github.com/MagicStack/asyncpg/issues/727

"""  # noqa

from __future__ import annotations

from collections import deque
import decimal
import json as _py_json
import re
import time

from . import json
from . import ranges
from .array import ARRAY as PGARRAY
from .base import _DECIMAL_TYPES
from .base import _FLOAT_TYPES
from .base import _INT_TYPES
from .base import ENUM
from .base import INTERVAL
from .base import OID
from .base import PGCompiler
from .base import PGDialect
from .base import PGExecutionContext
from .base import PGIdentifierPreparer
from .base import REGCLASS
from .base import REGCONFIG
from .types import BIT
from .types import BYTEA
from .types import CITEXT
from ... import exc
from ... import pool
from ... import util
from ...connectors.asyncio import AsyncAdapt_terminate
from ...engine import AdaptedConnection
from ...engine import processors
from ...sql import sqltypes
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only


class AsyncpgARRAY(PGARRAY):
    render_bind_cast = True


class AsyncpgString(sqltypes.String):
    render_bind_cast = True


class AsyncpgREGCONFIG(REGCONFIG):
    render_bind_cast = True


class AsyncpgTime(sqltypes.Time):
    render_bind_cast = True


class AsyncpgBit(BIT):
    render_bind_cast = True


class AsyncpgByteA(BYTEA):
    render_bind_cast = True


class AsyncpgDate(sqltypes.Date):
    render_bind_cast = True


class AsyncpgDateTime(sqltypes.DateTime):
    render_bind_cast = True


class AsyncpgBoolean(sqltypes.Boolean):
    render_bind_cast = True


class AsyncPgInterval(INTERVAL):
    render_bind_cast = True

    @classmethod
    def adapt_emulated_to_native(cls, interval, **kw):
        return AsyncPgInterval(precision=interval.second_precision)


class AsyncPgEnum(ENUM):
    render_bind_cast = True


class AsyncpgInteger(sqltypes.Integer):
    render_bind_cast = True


class AsyncpgSmallInteger(sqltypes.SmallInteger):
    render_bind_cast = True


class AsyncpgBigInteger(sqltypes.BigInteger):
    render_bind_cast = True


class AsyncpgJSON(json.JSON):
    def result_processor(self, dialect, coltype):
        return None


class AsyncpgJSONB(json.JSONB):
    def result_processor(self, dialect, coltype):
        return None


class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType):
    pass


class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
    __visit_name__ = "json_int_index"

    render_bind_cast = True


class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
    __visit_name__ = "json_str_index"

    render_bind_cast = True


class AsyncpgJSONPathType(json.JSONPathType):
    def bind_processor(self, dialect):
        def process(value):
            if isinstance(value, str):
                # If it's already a string assume that it's in json path
                # format. This allows using cast with json paths literals
                return value
            elif value:
                tokens = [str(elem) for elem in value]
                return tokens
            else:
                return []

        return process


class AsyncpgNumeric(sqltypes.Numeric):
    render_bind_cast = True

    def bind_processor(self, dialect):
        return None

    def result_processor(self, dialect, coltype):
        if self.asdecimal:
            if coltype in _FLOAT_TYPES:
                return processors.to_decimal_processor_factory(
                    decimal.Decimal, self._effective_decimal_return_scale
                )
            elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
                # asyncpg returns Decimal natively for 1700 (NUMERIC)
                return None
            else:
                raise exc.InvalidRequestError(
                    "Unknown PG numeric type: %d" % coltype
                )
        else:
            if coltype in _FLOAT_TYPES:
                # asyncpg returns float natively for 701 (FLOAT8)
                return None
            elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
                return processors.to_float
            else:
                raise exc.InvalidRequestError(
                    "Unknown PG numeric type: %d" % coltype
                )


class AsyncpgFloat(AsyncpgNumeric, sqltypes.Float):
    __visit_name__ = "float"
    render_bind_cast = True


class AsyncpgREGCLASS(REGCLASS):
    render_bind_cast = True


class AsyncpgOID(OID):
    render_bind_cast = True


class AsyncpgCHAR(sqltypes.CHAR):
    render_bind_cast = True


class _AsyncpgRange(ranges.AbstractSingleRangeImpl):
    def bind_processor(self, dialect):
        asyncpg_Range = dialect.dbapi.asyncpg.Range

        def to_range(value):
            if isinstance(value, ranges.Range):
                value = asyncpg_Range(
                    value.lower,
                    value.upper,
                    lower_inc=value.bounds[0] == "[",
                    upper_inc=value.bounds[1] == "]",
                    empty=value.empty,
                )
            return value

        return to_range

    def result_processor(self, dialect, coltype):
        def to_range(value):
            if value is not None:
                empty = value.isempty
                value = ranges.Range(
                    value.lower,
                    value.upper,
                    bounds=f"{'[' if empty or value.lower_inc else '('}"  # type: ignore  # noqa: E501
                    f"{']' if not empty and value.upper_inc else ')'}",
                    empty=empty,
                )
            return value

        return to_range


class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl):
    def bind_processor(self, dialect):
        asyncpg_Range = dialect.dbapi.asyncpg.Range

        NoneType = type(None)

        # outer function renamed from "to_range" so that it no longer
        # shadows the per-element converter defined within it
        def to_range_list(value):
            if isinstance(value, (str, NoneType)):
                return value

            def to_range(value):
                if isinstance(value, ranges.Range):
                    value = asyncpg_Range(
                        value.lower,
                        value.upper,
                        lower_inc=value.bounds[0] == "[",
                        upper_inc=value.bounds[1] == "]",
                        empty=value.empty,
                    )
                return value

            return [to_range(element) for element in value]

        return to_range_list

    def result_processor(self, dialect, coltype):
        def to_range_array(value):
            def to_range(rvalue):
                if rvalue is not None:
                    empty = rvalue.isempty
                    rvalue = ranges.Range(
                        rvalue.lower,
                        rvalue.upper,
                        bounds=f"{'[' if empty or rvalue.lower_inc else '('}"  # type: ignore  # noqa: E501
                        f"{']' if not empty and rvalue.upper_inc else ')'}",
                        empty=empty,
                    )
                return rvalue

            if value is not None:
                value = ranges.MultiRange(to_range(elem) for elem in value)

            return value

        return to_range_array


class PGExecutionContext_asyncpg(PGExecutionContext):
    def handle_dbapi_exception(self, e):
        if isinstance(
            e,
            (
                self.dialect.dbapi.InvalidCachedStatementError,
                self.dialect.dbapi.InternalServerError,
            ),
        ):
            self.dialect._invalidate_schema_cache()

    def pre_exec(self):
        if self.isddl:
            self.dialect._invalidate_schema_cache()

        self.cursor._invalidate_schema_cache_asof = (
            self.dialect._invalidate_schema_cache_asof
        )

        if not self.compiled:
            return

    def create_server_side_cursor(self):
        return self._dbapi_connection.cursor(server_side=True)


class PGCompiler_asyncpg(PGCompiler):
    pass


class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer):
    pass


class AsyncAdapt_asyncpg_cursor:
    __slots__ = (
        "_adapt_connection",
        "_connection",
        "_rows",
        "description",
        "arraysize",
        "rowcount",
        "_cursor",
        "_invalidate_schema_cache_asof",
    )

    server_side = False
    _awaitable_cursor_close: bool = False

    def __init__(self, adapt_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection
        self._rows = deque()
        self._cursor = None
        self.description = None
        self.arraysize = 1
        self.rowcount = -1
        self._invalidate_schema_cache_asof = 0

    async def _async_soft_close(self) -> None:
        return

    def close(self):
        self._rows.clear()

    def _handle_exception(self, error):
        self._adapt_connection._handle_exception(error)

    async def _prepare_and_execute(self, operation, parameters):
        adapt_connection = self._adapt_connection

        async with adapt_connection._execute_mutex:
            if not adapt_connection._started:
                await adapt_connection._start_transaction()

            if parameters is None:
                parameters = ()

            try:
                prepared_stmt, attributes = await adapt_connection._prepare(
                    operation, self._invalidate_schema_cache_asof
                )

                if attributes:
                    self.description = [
                        (
                            attr.name,
                            attr.type.oid,
                            None,
                            None,
                            None,
                            None,
                            None,
                        )
                        for attr in attributes
                    ]
                else:
                    self.description = None

                if self.server_side:
                    self._cursor = await prepared_stmt.cursor(*parameters)
                    self.rowcount = -1
                else:
                    self._rows = deque(await prepared_stmt.fetch(*parameters))
                    status = prepared_stmt.get_statusmsg()

                    reg = re.match(
                        r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)",
                        status or "",
                    )
                    if reg:
                        self.rowcount = int(reg.group(1))
                    else:
                        self.rowcount = -1

            except Exception as error:
                self._handle_exception(error)

    async def _executemany(self, operation, seq_of_parameters):
        adapt_connection = self._adapt_connection

        self.description = None
        async with adapt_connection._execute_mutex:
            await adapt_connection._check_type_cache_invalidation(
                self._invalidate_schema_cache_asof
            )

            if not adapt_connection._started:
                await adapt_connection._start_transaction()

            try:
                return await self._connection.executemany(
                    operation, seq_of_parameters
                )
            except Exception as error:
                self._handle_exception(error)

    def execute(self, operation, parameters=None):
        self._adapt_connection.await_(
            self._prepare_and_execute(operation, parameters)
        )

    def executemany(self, operation, seq_of_parameters):
        return self._adapt_connection.await_(
            self._executemany(operation, seq_of_parameters)
        )

    def setinputsizes(self, *inputsizes):
        raise NotImplementedError()

    def __iter__(self):
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self):
        if self._rows:
            return self._rows.popleft()
        else:
            return None

    def fetchmany(self, size=None):
        if size is None:
            size = self.arraysize

        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]

    def fetchall(self):
        retval = list(self._rows)
        self._rows.clear()
        return retval


class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
    server_side = True
    __slots__ = ("_rowbuffer",)

    def __init__(self, adapt_connection):
        super().__init__(adapt_connection)
        self._rowbuffer = deque()

    def close(self):
        self._cursor = None
        self._rowbuffer.clear()

    def _buffer_rows(self):
        assert self._cursor is not None
        new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
        self._rowbuffer.extend(new_rows)

    def __aiter__(self):
        return self

    async def __anext__(self):
        # note: ``yield`` can't be used here; that would turn this method
        # into an async generator function rather than a coroutine, which
        # would break ``async for``.  return one row at a time, refilling
        # the buffer as needed.
        if not self._rowbuffer:
            self._buffer_rows()
        if not self._rowbuffer:
            raise StopAsyncIteration()
        return self._rowbuffer.popleft()

    def fetchone(self):
        if not self._rowbuffer:
            self._buffer_rows()
            if not self._rowbuffer:
                return None
        return self._rowbuffer.popleft()

    def fetchmany(self, size=None):
        if size is None:
            return self.fetchall()

        if not self._rowbuffer:
            self._buffer_rows()

        assert self._cursor is not None
        rb = self._rowbuffer
        lb = len(rb)
        if size > lb:
            rb.extend(
                self._adapt_connection.await_(self._cursor.fetch(size - lb))
            )

        return [rb.popleft() for _ in range(min(size, len(rb)))]

    def fetchall(self):
        ret = list(self._rowbuffer)
        ret.extend(self._adapt_connection.await_(self._all()))
        self._rowbuffer.clear()
        return ret

    async def _all(self):
        rows = []

        # TODO: looks like we have to hand-roll some kind of batching here.
        # hardcoding for the moment but this should be improved.
        while True:
            batch = await self._cursor.fetch(1000)
            if batch:
                rows.extend(batch)
                continue
            else:
                break
        return rows

    def executemany(self, operation, seq_of_parameters):
        raise NotImplementedError(
            "server side cursor doesn't support executemany yet"
        )


class AsyncAdapt_asyncpg_connection(AsyncAdapt_terminate, AdaptedConnection):
    __slots__ = (
        "dbapi",
        "isolation_level",
        "_isolation_setting",
        "readonly",
        "deferrable",
        "_transaction",
        "_started",
        "_prepared_statement_cache",
        "_prepared_statement_name_func",
        "_invalidate_schema_cache_asof",
        "_execute_mutex",
    )

    await_ = staticmethod(await_only)

    def __init__(
        self,
        dbapi,
        connection,
        prepared_statement_cache_size=100,
        prepared_statement_name_func=None,
    ):
        self.dbapi = dbapi
        self._connection = connection
        self.isolation_level = self._isolation_setting = None
        self.readonly = False
        self.deferrable = False
        self._transaction = None
        self._started = False
        self._invalidate_schema_cache_asof = time.time()
        self._execute_mutex = asyncio.Lock()

        if prepared_statement_cache_size:
            self._prepared_statement_cache = util.LRUCache(
                prepared_statement_cache_size
            )
        else:
            self._prepared_statement_cache = None

        if prepared_statement_name_func:
            self._prepared_statement_name_func = prepared_statement_name_func
        else:
            self._prepared_statement_name_func = self._default_name_func

    async def _check_type_cache_invalidation(self, invalidate_timestamp):
        if invalidate_timestamp > self._invalidate_schema_cache_asof:
            await self._connection.reload_schema_state()
            self._invalidate_schema_cache_asof = invalidate_timestamp

    async def _prepare(self, operation, invalidate_timestamp):
        await self._check_type_cache_invalidation(invalidate_timestamp)

        cache = self._prepared_statement_cache
        if cache is None:
            prepared_stmt = await self._connection.prepare(
                operation, name=self._prepared_statement_name_func()
            )
            attributes = prepared_stmt.get_attributes()
            return prepared_stmt, attributes

        # asyncpg uses a type cache for the "attributes" which seems to go
        # stale independently of the PreparedStatement itself, so place that
        # collection in the cache as well.
        if operation in cache:
            prepared_stmt, attributes, cached_timestamp = cache[operation]

            # prepared statements themselves also go stale for certain DDL
            # changes such as the size of a VARCHAR changing, so there is
            # also a cross-connection invalidation timestamp
            if cached_timestamp > invalidate_timestamp:
                return prepared_stmt, attributes

        prepared_stmt = await self._connection.prepare(
            operation, name=self._prepared_statement_name_func()
        )
        attributes = prepared_stmt.get_attributes()
        cache[operation] = (prepared_stmt, attributes, time.time())

        return prepared_stmt, attributes

    def _handle_exception(self, error):
        if self._connection.is_closed():
            self._transaction = None
            self._started = False

        if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
            exception_mapping = self.dbapi._asyncpg_error_translate

            for super_ in type(error).__mro__:
                if super_ in exception_mapping:
                    translated_error = exception_mapping[super_](
                        "%s: %s" % (type(error), error)
                    )
                    translated_error.pgcode = translated_error.sqlstate = (
                        getattr(error, "sqlstate", None)
                    )
                    raise translated_error from error
            else:
                raise error
        else:
            raise error

    @property
    def autocommit(self):
        return self.isolation_level == "autocommit"

    @autocommit.setter
    def autocommit(self, value):
        if value:
            self.isolation_level = "autocommit"
        else:
            self.isolation_level = self._isolation_setting

    def ping(self):
        try:
            _ = self.await_(self._async_ping())
        except Exception as error:
            self._handle_exception(error)

    async def _async_ping(self):
        if self._transaction is None and self.isolation_level != "autocommit":
            # create a transaction explicitly to support pgbouncer
            # transaction mode.  See #10226
            tr = self._connection.transaction()
            await tr.start()
            try:
                await self._connection.fetchrow(";")
            finally:
                await tr.rollback()
        else:
            await self._connection.fetchrow(";")

    def set_isolation_level(self, level):
        if self._started:
            self.rollback()
        self.isolation_level = self._isolation_setting = level

    async def _start_transaction(self):
        if self.isolation_level == "autocommit":
            return

        try:
            self._transaction = self._connection.transaction(
                isolation=self.isolation_level,
                readonly=self.readonly,
                deferrable=self.deferrable,
            )
            await self._transaction.start()
        except Exception as error:
            self._handle_exception(error)
        else:
            self._started = True

    def cursor(self, server_side=False):
        if server_side:
            return AsyncAdapt_asyncpg_ss_cursor(self)
        else:
            return AsyncAdapt_asyncpg_cursor(self)

    async def _rollback_and_discard(self):
        try:
            await self._transaction.rollback()
        finally:
            # if asyncpg .rollback() was actually called, then whether or
            # not it raised or succeeded, the transaction is done, discard it
            self._transaction = None
            self._started = False

    async def _commit_and_discard(self):
        try:
            await self._transaction.commit()
        finally:
            # if asyncpg .commit() was actually called, then whether or
            # not it raised or succeeded, the transaction is done, discard it
            self._transaction = None
            self._started = False

    def rollback(self):
        if self._started:
            try:
                self.await_(self._rollback_and_discard())
                self._transaction = None
                self._started = False
            except Exception as error:
                # don't dereference asyncpg transaction if we didn't
                # actually try to call rollback() on it
                self._handle_exception(error)

    def commit(self):
        if self._started:
            try:
                self.await_(self._commit_and_discard())
                self._transaction = None
                self._started = False
            except Exception as error:
                # don't dereference asyncpg transaction if we didn't
                # actually try to call commit() on it
                self._handle_exception(error)

    def close(self):
        self.rollback()

        self.await_(self._connection.close())

    def _terminate_handled_exceptions(self):
        return super()._terminate_handled_exceptions() + (
            self.dbapi.asyncpg.PostgresError,
        )

    async def _terminate_graceful_close(self) -> None:
        # timeout added in asyncpg 0.14.0 December 2017
        await self._connection.close(timeout=2)
        self._started = False

    def _terminate_force_close(self) -> None:
        self._connection.terminate()
        self._started = False

    @staticmethod
    def _default_name_func():
        return None


class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection):
    __slots__ = ()

    await_ = staticmethod(await_fallback)


class AsyncAdapt_asyncpg_dbapi:
    def __init__(self, asyncpg):
        self.asyncpg = asyncpg
        self.paramstyle = "numeric_dollar"

    def connect(self, *arg, **kw):
        async_fallback = kw.pop("async_fallback", False)
        creator_fn = kw.pop("async_creator_fn", self.asyncpg.connect)
        prepared_statement_cache_size = kw.pop(
            "prepared_statement_cache_size", 100
        )
        prepared_statement_name_func = kw.pop(
            "prepared_statement_name_func", None
        )

        if util.asbool(async_fallback):
            return AsyncAdaptFallback_asyncpg_connection(
                self,
                await_fallback(creator_fn(*arg, **kw)),
                prepared_statement_cache_size=prepared_statement_cache_size,
                prepared_statement_name_func=prepared_statement_name_func,
            )
        else:
            return AsyncAdapt_asyncpg_connection(
                self,
                await_only(creator_fn(*arg, **kw)),
                prepared_statement_cache_size=prepared_statement_cache_size,
                prepared_statement_name_func=prepared_statement_name_func,
            )

    class Error(Exception):
        pass

    class Warning(Exception):  # noqa
        pass

    class InterfaceError(Error):
        pass

    class DatabaseError(Error):
        pass

    class InternalError(DatabaseError):
        pass

    class OperationalError(DatabaseError):
        pass

    class ProgrammingError(DatabaseError):
        pass

    class IntegrityError(DatabaseError):
        pass

    class DataError(DatabaseError):
        pass

    class NotSupportedError(DatabaseError):
        pass

    class InternalServerError(InternalError):
        pass

    class InvalidCachedStatementError(NotSupportedError):
        def __init__(self, message):
            super().__init__(
                message + " (SQLAlchemy asyncpg dialect will now invalidate "
                "all prepared caches in response to this exception)",
            )

    # pep-249 datatype placeholders.  As of SQLAlchemy 2.0 these aren't
    # used, however the test suite looks for these in a few cases.
    STRING = util.symbol("STRING")
    NUMBER = util.symbol("NUMBER")
    DATETIME = util.symbol("DATETIME")

    @util.memoized_property
    def _asyncpg_error_translate(self):
        import asyncpg

        return {
            asyncpg.exceptions.IntegrityConstraintViolationError: self.IntegrityError,  # noqa: E501
            asyncpg.exceptions.PostgresError: self.Error,
            asyncpg.exceptions.SyntaxOrAccessError: self.ProgrammingError,
            asyncpg.exceptions.InterfaceError: self.InterfaceError,
            asyncpg.exceptions.InvalidCachedStatementError: self.InvalidCachedStatementError,  # noqa: E501
            asyncpg.exceptions.InternalServerError: self.InternalServerError,
        }

    def Binary(self, value):
        return value


class PGDialect_asyncpg(PGDialect):
    driver = "asyncpg"
    supports_statement_cache = True

    supports_server_side_cursors = True

    render_bind_cast = True
    has_terminate = True

    default_paramstyle = "numeric_dollar"
    supports_sane_multi_rowcount = False
    execution_ctx_cls = PGExecutionContext_asyncpg
    statement_compiler = PGCompiler_asyncpg
    preparer = PGIdentifierPreparer_asyncpg

    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            sqltypes.String: AsyncpgString,
            sqltypes.ARRAY: AsyncpgARRAY,
            BIT: AsyncpgBit,
            CITEXT: CITEXT,
            REGCONFIG: AsyncpgREGCONFIG,
            sqltypes.Time: AsyncpgTime,
            sqltypes.Date: AsyncpgDate,
            sqltypes.DateTime: AsyncpgDateTime,
            sqltypes.Interval: AsyncPgInterval,
            INTERVAL: AsyncPgInterval,
            sqltypes.Boolean: AsyncpgBoolean,
            sqltypes.Integer: AsyncpgInteger,
            sqltypes.SmallInteger: AsyncpgSmallInteger,
            sqltypes.BigInteger: AsyncpgBigInteger,
            sqltypes.Numeric: AsyncpgNumeric,
            sqltypes.Float: AsyncpgFloat,
            sqltypes.JSON: AsyncpgJSON,
            sqltypes.LargeBinary: AsyncpgByteA,
            json.JSONB: AsyncpgJSONB,
            sqltypes.JSON.JSONPathType: AsyncpgJSONPathType,
            sqltypes.JSON.JSONIndexType: AsyncpgJSONIndexType,
            sqltypes.JSON.JSONIntIndexType: AsyncpgJSONIntIndexType,
            sqltypes.JSON.JSONStrIndexType: AsyncpgJSONStrIndexType,
            sqltypes.Enum: AsyncPgEnum,
            OID: AsyncpgOID,
            REGCLASS: AsyncpgREGCLASS,
            sqltypes.CHAR: AsyncpgCHAR,
            ranges.AbstractSingleRange: _AsyncpgRange,
            ranges.AbstractMultiRange: _AsyncpgMultiRange,
        },
    )
    is_async = True
    _invalidate_schema_cache_asof = 0

    def _invalidate_schema_cache(self):
        self._invalidate_schema_cache_asof = time.time()

    @util.memoized_property
    def _dbapi_version(self):
        if self.dbapi and hasattr(self.dbapi, "__version__"):
            return tuple(
                [
                    int(x)
                    for x in re.findall(
                        r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
                    )
                ]
            )
        else:
            return (99, 99, 99)

    @classmethod
    def import_dbapi(cls):
        return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg"))

    @util.memoized_property
    def _isolation_lookup(self):
        return {
            "AUTOCOMMIT": "autocommit",
            "READ COMMITTED": "read_committed",
            "REPEATABLE READ": "repeatable_read",
            "SERIALIZABLE": "serializable",
        }

    def get_isolation_level_values(self, dbapi_connection):
        return list(self._isolation_lookup)

    def set_isolation_level(self, dbapi_connection, level):
        dbapi_connection.set_isolation_level(self._isolation_lookup[level])

    def detect_autocommit_setting(self, dbapi_conn) -> bool:
        return bool(dbapi_conn.autocommit)

    def set_readonly(self, connection, value):
        connection.readonly = value

    def get_readonly(self, connection):
        return connection.readonly

    def set_deferrable(self, connection, value):
        connection.deferrable = value

    def get_deferrable(self, connection):
        return connection.deferrable

    def do_terminate(self, dbapi_connection) -> None:
        dbapi_connection.terminate()

    def create_connect_args(self, url):
        opts = url.translate_connect_args(username="user")
        multihosts, multiports = self._split_multihost_from_url(url)

        opts.update(url.query)

        if multihosts:
            assert multiports
            if len(multihosts) == 1:
                opts["host"] = multihosts[0]
                if multiports[0] is not None:
                    opts["port"] = multiports[0]
            elif not all(multihosts):
                raise exc.ArgumentError(
                    "All hosts are required to be present"
                    " for asyncpg multiple host URL"
                )
            elif not all(multiports):
                raise exc.ArgumentError(
                    "All ports are required to be present"
                    " for asyncpg multiple host URL"
                )
            else:
                opts["host"] = list(multihosts)
                opts["port"] = list(multiports)
        else:
            util.coerce_kw_type(opts, "port", int)
        util.coerce_kw_type(opts, "prepared_statement_cache_size", int)
        return ([], opts)

    def do_ping(self, dbapi_connection):
        dbapi_connection.ping()
        return True

    @classmethod
    def get_pool_class(cls, url):
        async_fallback = url.query.get("async_fallback", False)

        if util.asbool(async_fallback):
            return pool.FallbackAsyncAdaptedQueuePool
        else:
            return pool.AsyncAdaptedQueuePool

    def is_disconnect(self, e, connection, cursor):
        if connection:
            return connection._connection.is_closed()
        else:
            return isinstance(
                e, self.dbapi.InterfaceError
            ) and "connection is closed" in str(e)

    async def setup_asyncpg_json_codec(self, conn):
        """set up JSON codec for asyncpg.

        This occurs for all new connections and
        can be overridden by third party dialects.

        .. versionadded:: 1.4.27
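
        For example, a dialect subclass could override this method to
        register a different codec; a minimal sketch (``MyAsyncpgDialect``
        is hypothetical, and only asyncpg's documented ``set_type_codec()``
        API is used)::

            class MyAsyncpgDialect(PGDialect_asyncpg):
                async def setup_asyncpg_json_codec(self, conn):
                    # pass JSON text through unparsed instead of decoding
                    await conn._connection.set_type_codec(
                        "json",
                        encoder=str,
                        decoder=str,
                        schema="pg_catalog",
                        format="text",
                    )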

        """

        asyncpg_connection = conn._connection
        deserializer = self._json_deserializer or _py_json.loads

        def _json_decoder(bin_value):
            return deserializer(bin_value.decode())

        await asyncpg_connection.set_type_codec(
            "json",
            encoder=str.encode,
            decoder=_json_decoder,
            schema="pg_catalog",
            format="binary",
        )

    async def setup_asyncpg_jsonb_codec(self, conn):
        """set up JSONB codec for asyncpg.

        This occurs for all new connections and
        can be overridden by third party dialects.

        .. versionadded:: 1.4.27

        """

        asyncpg_connection = conn._connection
        deserializer = self._json_deserializer or _py_json.loads

        def _jsonb_encoder(str_value):
            # \x01 is the prefix for jsonb used by PostgreSQL.
            # asyncpg requires it when format='binary'
            return b"\x01" + str_value.encode()

        def _jsonb_decoder(bin_value):
            # the byte is the \x01 prefix for jsonb used by PostgreSQL.
            # asyncpg returns it when format='binary'
            return deserializer(bin_value[1:].decode())

        await asyncpg_connection.set_type_codec(
            "jsonb",
            encoder=_jsonb_encoder,
            decoder=_jsonb_decoder,
            schema="pg_catalog",
            format="binary",
        )

    async def _disable_asyncpg_inet_codecs(self, conn):
        asyncpg_connection = conn._connection

        await asyncpg_connection.set_type_codec(
            "inet",
            encoder=lambda s: s,
            decoder=lambda s: s,
            schema="pg_catalog",
            format="text",
        )

        await asyncpg_connection.set_type_codec(
            "cidr",
            encoder=lambda s: s,
            decoder=lambda s: s,
            schema="pg_catalog",
            format="text",
        )

    def on_connect(self):
        """on_connect for asyncpg

        A major component of this for asyncpg is to set up type decoders at
        the asyncpg level.

        See https://github.com/MagicStack/asyncpg/issues/623 for
        notes on JSON/JSONB implementation.

        """

        super_connect = super().on_connect()

        def connect(conn):
            conn.await_(self.setup_asyncpg_json_codec(conn))
            conn.await_(self.setup_asyncpg_jsonb_codec(conn))

            if self._native_inet_types is False:
                conn.await_(self._disable_asyncpg_inet_codecs(conn))
            if super_connect is not None:
                super_connect(conn)

        return connect

    def get_driver_connection(self, connection):
        return connection._connection


dialect = PGDialect_asyncpg