1# dialects/postgresql/asyncpg.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
3# file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: ignore-errors
8
9r"""
10.. dialect:: postgresql+asyncpg
11 :name: asyncpg
12 :dbapi: asyncpg
13 :connectstring: postgresql+asyncpg://user:password@host:port/dbname[?key=value&key=value...]
14 :url: https://magicstack.github.io/asyncpg/
15
16The asyncpg dialect is SQLAlchemy's first Python asyncio dialect.
17
18Using a special asyncio mediation layer, the asyncpg dialect is usable
19as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
20extension package.
21
22This dialect should normally be used only with the
23:func:`_asyncio.create_async_engine` engine creation function::
24
25 from sqlalchemy.ext.asyncio import create_async_engine
26
27 engine = create_async_engine(
28 "postgresql+asyncpg://user:pass@hostname/dbname"
29 )
30
31.. versionadded:: 1.4
32
33.. note::
34
   By default asyncpg does not decode the ``json`` and ``jsonb`` types and
   returns them as strings. SQLAlchemy sets a default type decoder for the
   ``json`` and ``jsonb`` types using the Python standard library
   ``json.loads`` function. The JSON implementation used can be changed by
   passing the ``json_deserializer`` parameter when creating the engine with
   :func:`create_engine` or :func:`create_async_engine`.
41
42.. _asyncpg_multihost:
43
44Multihost Connections
45--------------------------
46
The asyncpg dialect supports multiple fallback hosts in the same way as the
psycopg2 and psycopg dialects. The syntax is the same, using
``host=<host>:<port>`` combinations as additional query string arguments;
however, there is no default port, so when more than one host is given, every
host must include an explicit port number; otherwise an exception is raised::
53
54 engine = create_async_engine(
55 "postgresql+asyncpg://user:password@/dbname?host=HostA:5432&host=HostB:5432&host=HostC:5432"
56 )
57
58For complete background on this syntax, see :ref:`psycopg2_multi_host`.
59
60.. versionadded:: 2.0.18
61
62.. seealso::
63
64 :ref:`psycopg2_multi_host`
65
66.. _asyncpg_prepared_statement_cache:
67
68Prepared Statement Cache
69--------------------------
70
The asyncpg SQLAlchemy dialect makes use of ``asyncpg.connection.prepare()``
for all statements. The prepared statement objects are cached after
construction, which appears to grant a performance improvement of 10% or more
for statement invocation. The cache is on a per-DBAPI connection basis, which
75means that the primary storage for prepared statements is within DBAPI
76connections pooled within the connection pool. The size of this cache
77defaults to 100 statements per DBAPI connection and may be adjusted using the
78``prepared_statement_cache_size`` DBAPI argument (note that while this argument
79is implemented by SQLAlchemy, it is part of the DBAPI emulation portion of the
80asyncpg dialect, therefore is handled as a DBAPI argument, not a dialect
81argument)::
82
83
84 engine = create_async_engine(
85 "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500"
86 )
87
88To disable the prepared statement cache, use a value of zero::
89
90 engine = create_async_engine(
91 "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0"
92 )
93
94.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg.
95
96
97.. warning:: The ``asyncpg`` database driver necessarily uses caches for
98 PostgreSQL type OIDs, which become stale when custom PostgreSQL datatypes
99 such as ``ENUM`` objects are changed via DDL operations. Additionally,
100 prepared statements themselves which are optionally cached by SQLAlchemy's
101 driver as described above may also become "stale" when DDL has been emitted
102 to the PostgreSQL database which modifies the tables or other objects
103 involved in a particular prepared statement.
104
105 The SQLAlchemy asyncpg dialect will invalidate these caches within its local
106 process when statements that represent DDL are emitted on a local
107 connection, but this is only controllable within a single Python process /
108 database engine. If DDL changes are made from other database engines
109 and/or processes, a running application may encounter asyncpg exceptions
110 ``InvalidCachedStatementError`` and/or ``InternalServerError("cache lookup
111 failed for type <oid>")`` if it refers to pooled database connections which
112 operated upon the previous structures. The SQLAlchemy asyncpg dialect will
113 recover from these error cases when the driver raises these exceptions by
114 clearing its internal caches as well as those of the asyncpg driver in
115 response to them, but cannot prevent them from being raised in the first
116 place if the cached prepared statement or asyncpg type caches have gone
117 stale, nor can it retry the statement as the PostgreSQL transaction is
118 invalidated when these errors occur.
119
120.. _asyncpg_prepared_statement_name:
121
122Prepared Statement Name with PGBouncer
123--------------------------------------
124
125By default, asyncpg enumerates prepared statements in numeric order, which
126can lead to errors if a name has already been taken for another prepared
127statement. This issue can arise if your application uses database proxies
128such as PgBouncer to handle connections. One possible workaround is to
129use dynamic prepared statement names, which asyncpg now supports through
130an optional ``name`` value for the statement name. This allows you to
131generate your own unique names that won't conflict with existing ones.
132To achieve this, you can provide a function that will be called every time
133a prepared statement is prepared::
134
135 from uuid import uuid4
136
137 engine = create_async_engine(
138 "postgresql+asyncpg://user:pass@somepgbouncer/dbname",
139 poolclass=NullPool,
140 connect_args={
141 "prepared_statement_name_func": lambda: f"__asyncpg_{uuid4()}__",
142 },
143 )
144
145.. seealso::
146
147 https://github.com/MagicStack/asyncpg/issues/837
148
149 https://github.com/sqlalchemy/sqlalchemy/issues/6467
150
.. warning:: When using PgBouncer, to prevent a buildup of unused prepared statements in
152 your application, it's important to use the :class:`.NullPool` pool
153 class, and to configure PgBouncer to use `DISCARD <https://www.postgresql.org/docs/current/sql-discard.html>`_
154 when returning connections. The DISCARD command is used to release resources held by the db connection,
155 including prepared statements. Without proper setup, prepared statements can
156 accumulate quickly and cause performance issues.
157
158Disabling the PostgreSQL JIT to improve ENUM datatype handling
159---------------------------------------------------------------
160
Asyncpg has an `issue <https://github.com/MagicStack/asyncpg/issues/727>`_ when
using PostgreSQL ENUM datatypes, where, upon the creation of new database
connections, an expensive query may be emitted in order to retrieve metadata
regarding custom types, which has been shown to negatively affect performance.
165To mitigate this issue, the PostgreSQL "jit" setting may be disabled from the
166client using this setting passed to :func:`_asyncio.create_async_engine`::
167
168 engine = create_async_engine(
169 "postgresql+asyncpg://user:password@localhost/tmp",
170 connect_args={"server_settings": {"jit": "off"}},
171 )
172
173.. seealso::
174
175 https://github.com/MagicStack/asyncpg/issues/727
176
177""" # noqa
178
179from __future__ import annotations
180
181from collections import deque
182import decimal
183import json as _py_json
184import re
185import time
186
187from . import json
188from . import ranges
189from .array import ARRAY as PGARRAY
190from .base import _DECIMAL_TYPES
191from .base import _FLOAT_TYPES
192from .base import _INT_TYPES
193from .base import ENUM
194from .base import INTERVAL
195from .base import OID
196from .base import PGCompiler
197from .base import PGDialect
198from .base import PGExecutionContext
199from .base import PGIdentifierPreparer
200from .base import REGCLASS
201from .base import REGCONFIG
202from .types import BIT
203from .types import BYTEA
204from .types import CITEXT
205from ... import exc
206from ... import pool
207from ... import util
208from ...engine import AdaptedConnection
209from ...engine import processors
210from ...sql import sqltypes
211from ...util.concurrency import asyncio
212from ...util.concurrency import await_fallback
213from ...util.concurrency import await_only
214
215
class AsyncpgARRAY(PGARRAY):
    """asyncpg ARRAY type with ``render_bind_cast`` enabled."""

    render_bind_cast = True


class AsyncpgString(sqltypes.String):
    """asyncpg String type with ``render_bind_cast`` enabled."""

    render_bind_cast = True


class AsyncpgREGCONFIG(REGCONFIG):
    """asyncpg REGCONFIG type with ``render_bind_cast`` enabled."""

    render_bind_cast = True


class AsyncpgTime(sqltypes.Time):
    """asyncpg Time type with ``render_bind_cast`` enabled."""

    render_bind_cast = True


class AsyncpgBit(BIT):
    """asyncpg BIT type with ``render_bind_cast`` enabled."""

    render_bind_cast = True


class AsyncpgByteA(BYTEA):
    """asyncpg BYTEA type with ``render_bind_cast`` enabled."""

    render_bind_cast = True


class AsyncpgDate(sqltypes.Date):
    """asyncpg Date type with ``render_bind_cast`` enabled."""

    render_bind_cast = True


class AsyncpgDateTime(sqltypes.DateTime):
    """asyncpg DateTime type with ``render_bind_cast`` enabled."""

    render_bind_cast = True


class AsyncpgBoolean(sqltypes.Boolean):
    """asyncpg Boolean type with ``render_bind_cast`` enabled."""

    render_bind_cast = True
250
251
class AsyncPgInterval(INTERVAL):
    """asyncpg INTERVAL type with ``render_bind_cast`` enabled."""

    render_bind_cast = True

    @classmethod
    def adapt_emulated_to_native(cls, interval, **kw):
        """Build the native interval type from a generic ``Interval``.

        Only ``second_precision`` is carried over (as ``precision``); the
        result is always an ``AsyncPgInterval``, regardless of ``cls``.
        """
        precision = interval.second_precision
        return AsyncPgInterval(precision=precision)


class AsyncPgEnum(ENUM):
    """asyncpg ENUM type with ``render_bind_cast`` enabled."""

    render_bind_cast = True


class AsyncpgInteger(sqltypes.Integer):
    """asyncpg Integer type with ``render_bind_cast`` enabled."""

    render_bind_cast = True


class AsyncpgSmallInteger(sqltypes.SmallInteger):
    """asyncpg SmallInteger type with ``render_bind_cast`` enabled."""

    render_bind_cast = True


class AsyncpgBigInteger(sqltypes.BigInteger):
    """asyncpg BigInteger type with ``render_bind_cast`` enabled."""

    render_bind_cast = True
274
275
class AsyncpgJSON(json.JSON):
    """JSON type for asyncpg with no SQLAlchemy-side result processing.

    Per the module notes, decoding of fetched ``json`` values is configured
    at the driver level (``json.loads`` unless ``json_deserializer`` is set).
    """

    def result_processor(self, dialect, coltype):
        # fetched values are returned exactly as the driver produced them
        return None


class AsyncpgJSONB(json.JSONB):
    """JSONB type for asyncpg with no SQLAlchemy-side result processing."""

    def result_processor(self, dialect, coltype):
        # fetched values are returned exactly as the driver produced them
        return None


class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType):
    """asyncpg JSON index type; adds nothing to the base implementation."""


class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
    """Integer JSON index rendered as ``json_int_index`` with a bind cast."""

    __visit_name__ = "json_int_index"

    render_bind_cast = True


class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
    """String JSON index rendered as ``json_str_index`` with a bind cast."""

    __visit_name__ = "json_str_index"

    render_bind_cast = True
300
301
class AsyncpgJSONPathType(json.JSONPathType):
    """JSON path type whose bound values are sent as a list of strings."""

    def bind_processor(self, dialect):
        def process(value):
            # A plain string is assumed to already be in JSON path format,
            # which allows using cast with JSON path literals.
            if isinstance(value, str):
                return value
            # Any other falsy value (None, empty sequence) binds as [].
            if not value:
                return []
            return [str(elem) for elem in value]

        return process
316
317
class AsyncpgNumeric(sqltypes.Numeric):
    """Numeric type for asyncpg.

    Bound values are passed through unchanged.  Result conversion depends on
    ``asdecimal`` and on the column type OID: OIDs whose values are already
    in the requested Python form are passed through, the others are
    converted, and unrecognized OIDs raise ``InvalidRequestError``.
    """

    render_bind_cast = True

    def bind_processor(self, dialect):
        return None

    def result_processor(self, dialect, coltype):
        if self.asdecimal:
            if coltype in _FLOAT_TYPES:
                # float OIDs are converted to Decimal using the type's
                # effective return scale
                return processors.to_decimal_processor_factory(
                    decimal.Decimal, self._effective_decimal_return_scale
                )
            if coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
                # decimal / integer OIDs are passed through unchanged
                return None
        else:
            if coltype in _FLOAT_TYPES:
                # float OIDs are passed through unchanged
                return None
            if coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
                return processors.to_float

        raise exc.InvalidRequestError(
            "Unknown PG numeric type: %d" % coltype
        )
347
348
class AsyncpgFloat(AsyncpgNumeric, sqltypes.Float):
    """Float type reusing the asyncpg numeric result handling."""

    __visit_name__ = "float"
    render_bind_cast = True


class AsyncpgREGCLASS(REGCLASS):
    """asyncpg REGCLASS type with ``render_bind_cast`` enabled."""

    render_bind_cast = True


class AsyncpgOID(OID):
    """asyncpg OID type with ``render_bind_cast`` enabled."""

    render_bind_cast = True


class AsyncpgCHAR(sqltypes.CHAR):
    """asyncpg CHAR type with ``render_bind_cast`` enabled."""

    render_bind_cast = True
364
365
class _AsyncpgRange(ranges.AbstractSingleRangeImpl):
    """Converts between SQLAlchemy ``Range`` and asyncpg ``Range`` objects."""

    def bind_processor(self, dialect):
        asyncpg_Range = dialect.dbapi.asyncpg.Range

        def to_range(value):
            # anything that is not a SQLAlchemy Range is bound unchanged
            if not isinstance(value, ranges.Range):
                return value
            return asyncpg_Range(
                value.lower,
                value.upper,
                lower_inc=value.bounds[0] == "[",
                upper_inc=value.bounds[1] == "]",
                empty=value.empty,
            )

        return to_range

    def result_processor(self, dialect, coltype):
        def to_range(value):
            if value is None:
                return None
            empty = value.isempty
            # an empty range is always reported with "[)" bounds
            left = "[" if empty or value.lower_inc else "("
            right = "]" if not empty and value.upper_inc else ")"
            return ranges.Range(
                value.lower,
                value.upper,
                bounds=left + right,
                empty=empty,
            )

        return to_range
397
398
class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl):
    """Converts between SQLAlchemy ``MultiRange`` values and sequences of
    asyncpg ``Range`` objects."""

    def bind_processor(self, dialect):
        asyncpg_Range = dialect.dbapi.asyncpg.Range

        def element_to_asyncpg(element):
            # convert one SQLAlchemy Range; other elements pass through
            if isinstance(element, ranges.Range):
                return asyncpg_Range(
                    element.lower,
                    element.upper,
                    lower_inc=element.bounds[0] == "[",
                    upper_inc=element.bounds[1] == "]",
                    empty=element.empty,
                )
            return element

        def to_range(value):
            # strings (e.g. multirange literals) and None are bound as-is
            if value is None or isinstance(value, str):
                return value
            return [element_to_asyncpg(element) for element in value]

        return to_range

    def result_processor(self, dialect, coltype):
        def element_from_asyncpg(rvalue):
            # convert one asyncpg Range; an empty range gets "[)" bounds
            if rvalue is None:
                return None
            empty = rvalue.isempty
            return ranges.Range(
                rvalue.lower,
                rvalue.upper,
                bounds=f"{'[' if empty or rvalue.lower_inc else '('}"  # type: ignore # noqa: E501
                f"{']' if not empty and rvalue.upper_inc else ')'}",
                empty=empty,
            )

        def to_range_array(value):
            if value is None:
                return None
            return ranges.MultiRange(
                element_from_asyncpg(elem) for elem in value
            )

        return to_range_array
444
445
class PGExecutionContext_asyncpg(PGExecutionContext):
    """Execution context that coordinates asyncpg cache invalidation."""

    def handle_dbapi_exception(self, e):
        # These errors can indicate stale prepared statements / type caches;
        # bump the dialect-wide invalidation timestamp in response.
        dbapi = self.dialect.dbapi
        stale_cache_errors = (
            dbapi.InvalidCachedStatementError,
            dbapi.InternalServerError,
        )
        if isinstance(e, stale_cache_errors):
            self.dialect._invalidate_schema_cache()

    def pre_exec(self):
        # DDL can change types and tables, invalidating cached statements
        if self.isddl:
            self.dialect._invalidate_schema_cache()

        # The cursor hands this timestamp to the connection, which reloads
        # its caches when they predate it.
        self.cursor._invalidate_schema_cache_asof = (
            self.dialect._invalidate_schema_cache_asof
        )

    def create_server_side_cursor(self):
        return self._dbapi_connection.cursor(server_side=True)
470
471
class PGCompiler_asyncpg(PGCompiler):
    """asyncpg statement compiler; currently identical to PGCompiler."""


class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer):
    """asyncpg identifier preparer; currently identical to its base."""
478
479
class AsyncAdapt_asyncpg_cursor:
    """pep-249 style cursor on top of an adapted asyncpg connection.

    Each statement is prepared through the adapted connection and, for a
    client-side cursor, the whole result is fetched eagerly into a deque
    from which the ``fetch*`` methods and iteration serve rows.
    """

    __slots__ = (
        "_adapt_connection",
        "_connection",
        "_rows",
        "description",
        "arraysize",
        "rowcount",
        "_cursor",
        "_invalidate_schema_cache_asof",
    )

    server_side = False

    # Extracts the row count from command status strings such as
    # "SELECT 3", "UPDATE 2" or "INSERT 0 1".
    _status_rowcount_re = re.compile(
        r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)"
    )

    def __init__(self, adapt_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection
        self._rows = deque()
        self._cursor = None
        self.description = None
        self.arraysize = 1
        self.rowcount = -1
        self._invalidate_schema_cache_asof = 0

    def close(self):
        self._rows.clear()

    def _handle_exception(self, error):
        self._adapt_connection._handle_exception(error)

    @staticmethod
    def _description_from_attributes(attributes):
        """Build pep-249 description 7-tuples (name, type OID, 5x None),
        or return None when the statement yields no columns."""
        if not attributes:
            return None
        return [
            (attr.name, attr.type.oid, None, None, None, None, None)
            for attr in attributes
        ]

    async def _prepare_and_execute(self, operation, parameters):
        adapt_connection = self._adapt_connection

        async with adapt_connection._execute_mutex:
            # begin the (lazy) transaction before the first statement
            if not adapt_connection._started:
                await adapt_connection._start_transaction()

            params = () if parameters is None else parameters

            try:
                prepared_stmt, attributes = await adapt_connection._prepare(
                    operation, self._invalidate_schema_cache_asof
                )
                self.description = self._description_from_attributes(
                    attributes
                )

                if self.server_side:
                    self._cursor = await prepared_stmt.cursor(*params)
                    self.rowcount = -1
                    return

                self._rows = deque(await prepared_stmt.fetch(*params))
                match = self._status_rowcount_re.match(
                    prepared_stmt.get_statusmsg() or ""
                )
                self.rowcount = int(match.group(1)) if match else -1
            except Exception as error:
                self._handle_exception(error)

    async def _executemany(self, operation, seq_of_parameters):
        adapt_connection = self._adapt_connection

        self.description = None
        async with adapt_connection._execute_mutex:
            await adapt_connection._check_type_cache_invalidation(
                self._invalidate_schema_cache_asof
            )

            if not adapt_connection._started:
                await adapt_connection._start_transaction()

            try:
                return await self._connection.executemany(
                    operation, seq_of_parameters
                )
            except Exception as error:
                self._handle_exception(error)

    def execute(self, operation, parameters=None):
        self._adapt_connection.await_(
            self._prepare_and_execute(operation, parameters)
        )

    def executemany(self, operation, seq_of_parameters):
        return self._adapt_connection.await_(
            self._executemany(operation, seq_of_parameters)
        )

    def setinputsizes(self, *inputsizes):
        raise NotImplementedError()

    def __iter__(self):
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self):
        return self._rows.popleft() if self._rows else None

    def fetchmany(self, size=None):
        if size is None:
            size = self.arraysize

        rows = self._rows
        count = min(size, len(rows))
        return [rows.popleft() for _ in range(count)]

    def fetchall(self):
        remaining = list(self._rows)
        self._rows.clear()
        return remaining
613
614
class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
    """Server-side cursor that pulls rows from an asyncpg cursor in batches.

    The inherited ``_prepare_and_execute`` stores the asyncpg cursor in
    ``self._cursor`` (since ``server_side`` is true); rows fetched from it
    are staged in ``self._rowbuffer``.
    """

    server_side = True
    __slots__ = ("_rowbuffer",)

    def __init__(self, adapt_connection):
        super().__init__(adapt_connection)
        self._rowbuffer = deque()

    def close(self):
        self._cursor = None
        self._rowbuffer.clear()

    def _buffer_rows(self):
        # synchronously (via the adapted connection's await_) fetch the
        # next batch of up to 50 rows into the buffer
        assert self._cursor is not None
        new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
        self._rowbuffer.extend(new_rows)

    def __iter__(self):
        # The inherited __iter__ drains ``self._rows``, which a server-side
        # cursor never fills; iterate the buffered asyncpg cursor instead.
        while True:
            while self._rowbuffer:
                yield self._rowbuffer.popleft()

            self._buffer_rows()
            if not self._rowbuffer:
                break

    def __aiter__(self):
        return self

    async def __anext__(self):
        # Must be a coroutine that returns one row or raises
        # StopAsyncIteration; a ``yield`` here would make it an async
        # generator, which ``async for`` cannot await.  Being native async
        # code, it awaits the asyncpg cursor directly rather than going
        # through the synchronous ``await_`` adapter.
        if not self._rowbuffer:
            assert self._cursor is not None
            self._rowbuffer.extend(await self._cursor.fetch(50))
        if not self._rowbuffer:
            raise StopAsyncIteration()
        return self._rowbuffer.popleft()

    def fetchone(self):
        if not self._rowbuffer:
            self._buffer_rows()
            if not self._rowbuffer:
                return None
        return self._rowbuffer.popleft()

    def fetchmany(self, size=None):
        if size is None:
            return self.fetchall()

        if not self._rowbuffer:
            self._buffer_rows()

        assert self._cursor is not None
        rb = self._rowbuffer
        lb = len(rb)
        if size > lb:
            # top up the buffer so that ``size`` rows can be returned
            rb.extend(
                self._adapt_connection.await_(self._cursor.fetch(size - lb))
            )

        return [rb.popleft() for _ in range(min(size, len(rb)))]

    def fetchall(self):
        ret = list(self._rowbuffer)
        ret.extend(self._adapt_connection.await_(self._all()))
        self._rowbuffer.clear()
        return ret

    async def _all(self):
        rows = []

        # TODO: looks like we have to hand-roll some kind of batching here.
        # hardcoding for the moment but this should be improved.
        while True:
            batch = await self._cursor.fetch(1000)
            if not batch:
                break
            rows.extend(batch)
        return rows

    def executemany(self, operation, seq_of_parameters):
        raise NotImplementedError(
            "server side cursor doesn't support executemany yet"
        )
692
693
class AsyncAdapt_asyncpg_connection(AdaptedConnection):
    """pep-249 style connection wrapping an asyncpg connection.

    Coroutines are driven to completion through ``await_`` (``await_only``
    here; the fallback subclass uses ``await_fallback``) so that the object
    presents a synchronous DBAPI interface.  Transactions are started lazily
    by cursors before their first statement, prepared statements may be
    cached per connection, and asyncpg exceptions are re-raised as the
    exception classes of ``AsyncAdapt_asyncpg_dbapi``.
    """

    __slots__ = (
        "dbapi",
        "isolation_level",
        "_isolation_setting",
        "readonly",
        "deferrable",
        "_transaction",
        "_started",
        "_prepared_statement_cache",
        "_prepared_statement_name_func",
        "_invalidate_schema_cache_asof",
        "_execute_mutex",
    )

    await_ = staticmethod(await_only)

    def __init__(
        self,
        dbapi,
        connection,
        prepared_statement_cache_size=100,
        prepared_statement_name_func=None,
    ):
        """
        :param dbapi: the ``AsyncAdapt_asyncpg_dbapi`` that created this
            connection.
        :param connection: the asyncpg connection being adapted.
        :param prepared_statement_cache_size: size of the per-connection LRU
            cache of prepared statements; a falsy value disables the cache.
        :param prepared_statement_name_func: zero-argument callable giving
            the name of each newly prepared statement; when not supplied,
            ``_default_name_func`` (which returns ``None``) is used.
        """
        self.dbapi = dbapi
        self._connection = connection
        self.isolation_level = self._isolation_setting = None
        self.readonly = False
        self.deferrable = False
        self._transaction = None
        self._started = False
        # compared with the timestamp passed in by cursors to decide whether
        # asyncpg's schema/type caches must be reloaded
        self._invalidate_schema_cache_asof = time.time()
        # cursors hold this lock while executing on this connection
        self._execute_mutex = asyncio.Lock()

        if prepared_statement_cache_size:
            self._prepared_statement_cache = util.LRUCache(
                prepared_statement_cache_size
            )
        else:
            self._prepared_statement_cache = None

        if prepared_statement_name_func:
            self._prepared_statement_name_func = prepared_statement_name_func
        else:
            self._prepared_statement_name_func = self._default_name_func

    async def _check_type_cache_invalidation(self, invalidate_timestamp):
        """Reload asyncpg's schema state when ``invalidate_timestamp`` is
        newer than the last reload recorded on this connection."""
        if invalidate_timestamp > self._invalidate_schema_cache_asof:
            await self._connection.reload_schema_state()
            self._invalidate_schema_cache_asof = invalidate_timestamp

    async def _prepare(self, operation, invalidate_timestamp):
        """Return ``(prepared_statement, attributes)`` for ``operation``.

        A cached entry is reused only if it was created after
        ``invalidate_timestamp``; otherwise the statement is prepared again
        and, when caching is enabled, stored with the current time.
        """
        await self._check_type_cache_invalidation(invalidate_timestamp)

        cache = self._prepared_statement_cache
        if cache is None:
            prepared_stmt = await self._connection.prepare(
                operation, name=self._prepared_statement_name_func()
            )
            attributes = prepared_stmt.get_attributes()
            return prepared_stmt, attributes

        # asyncpg uses a type cache for the "attributes" which seems to go
        # stale independently of the PreparedStatement itself, so place that
        # collection in the cache as well.
        if operation in cache:
            prepared_stmt, attributes, cached_timestamp = cache[operation]

            # preparedstatements themselves also go stale for certain DDL
            # changes such as size of a VARCHAR changing, so there is also
            # a cross-connection invalidation timestamp
            if cached_timestamp > invalidate_timestamp:
                return prepared_stmt, attributes

        prepared_stmt = await self._connection.prepare(
            operation, name=self._prepared_statement_name_func()
        )
        attributes = prepared_stmt.get_attributes()
        cache[operation] = (prepared_stmt, attributes, time.time())

        return prepared_stmt, attributes

    def _handle_exception(self, error):
        """Re-raise ``error``; this method always raises.

        If the asyncpg connection is closed, local transaction state is
        reset first.  Errors that are not already this DBAPI's ``Error`` are
        translated using the first class in their MRO found in
        ``dbapi._asyncpg_error_translate``, with ``pgcode``/``sqlstate``
        copied over and the original chained as ``__cause__``; unmapped
        errors are re-raised unchanged.
        """
        if self._connection.is_closed():
            self._transaction = None
            self._started = False

        if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
            exception_mapping = self.dbapi._asyncpg_error_translate

            for super_ in type(error).__mro__:
                if super_ in exception_mapping:
                    translated_error = exception_mapping[super_](
                        "%s: %s" % (type(error), error)
                    )
                    translated_error.pgcode = translated_error.sqlstate = (
                        getattr(error, "sqlstate", None)
                    )
                    raise translated_error from error
            else:
                # no mapped class in the MRO
                raise error
        else:
            raise error

    @property
    def autocommit(self):
        """True when the isolation level is ``"autocommit"``."""
        return self.isolation_level == "autocommit"

    @autocommit.setter
    def autocommit(self, value):
        # turning autocommit off restores the last explicitly set level
        if value:
            self.isolation_level = "autocommit"
        else:
            self.isolation_level = self._isolation_setting

    def ping(self):
        """Run a trivial statement; failures go through _handle_exception."""
        try:
            _ = self.await_(self._async_ping())
        except Exception as error:
            self._handle_exception(error)

    async def _async_ping(self):
        if self._transaction is None and self.isolation_level != "autocommit":
            # create a transaction explicitly to support pgbouncer
            # transaction mode. See #10226
            tr = self._connection.transaction()
            await tr.start()
            try:
                await self._connection.fetchrow(";")
            finally:
                await tr.rollback()
        else:
            await self._connection.fetchrow(";")

    def set_isolation_level(self, level):
        """Set the asyncpg isolation level, rolling back any transaction
        already started; the new level applies to the next transaction."""
        if self._started:
            self.rollback()
        self.isolation_level = self._isolation_setting = level

    async def _start_transaction(self):
        """Begin an asyncpg transaction (no-op in autocommit mode) using the
        current isolation / readonly / deferrable settings."""
        if self.isolation_level == "autocommit":
            return

        try:
            self._transaction = self._connection.transaction(
                isolation=self.isolation_level,
                readonly=self.readonly,
                deferrable=self.deferrable,
            )
            await self._transaction.start()
        except Exception as error:
            self._handle_exception(error)
        else:
            self._started = True

    def cursor(self, server_side=False):
        """Return a server-side or client-side (buffered) cursor."""
        if server_side:
            return AsyncAdapt_asyncpg_ss_cursor(self)
        else:
            return AsyncAdapt_asyncpg_cursor(self)

    async def _rollback_and_discard(self):
        try:
            await self._transaction.rollback()
        finally:
            # if asyncpg .rollback() was actually called, then whether or
            # not it raised or succeeded, the transaction is done, discard it
            self._transaction = None
            self._started = False

    async def _commit_and_discard(self):
        try:
            await self._transaction.commit()
        finally:
            # if asyncpg .commit() was actually called, then whether or
            # not it raised or succeeded, the transaction is done, discard it
            self._transaction = None
            self._started = False

    def rollback(self):
        """Roll back the current transaction, if one was started."""
        if self._started:
            try:
                self.await_(self._rollback_and_discard())
                self._transaction = None
                self._started = False
            except Exception as error:
                # state is deliberately not reset here: if asyncpg's
                # rollback() was reached, _rollback_and_discard's finally
                # block has already discarded the transaction; if it was
                # not reached, the transaction is still live.
                self._handle_exception(error)

    def commit(self):
        """Commit the current transaction, if one was started."""
        if self._started:
            try:
                self.await_(self._commit_and_discard())
                self._transaction = None
                self._started = False
            except Exception as error:
                # state is deliberately not reset here: if asyncpg's
                # commit() was reached, _commit_and_discard's finally
                # block has already discarded the transaction; if it was
                # not reached, the transaction is still live.
                self._handle_exception(error)

    def close(self):
        """Roll back any started transaction, then close the connection."""
        self.rollback()

        self.await_(self._connection.close())

    def terminate(self):
        """Close the connection, forcibly if a graceful close fails.

        Inside a greenlet, a graceful ``close(timeout=2)`` is attempted
        (wrapped in ``asyncio.shield``); on timeout, cancellation, ``OSError``
        or ``PostgresError`` the connection is terminated instead, and a
        ``CancelledError`` is re-raised afterwards.  Outside a greenlet the
        connection is terminated directly.  ``_started`` is reset in both
        cases (``_transaction`` is left unchanged).
        """
        if util.concurrency.in_greenlet():
            # in a greenlet; this is the "connection was invalidated"
            # case.
            try:
                # try to gracefully close; see #10717
                # timeout added in asyncpg 0.14.0 December 2017
                self.await_(asyncio.shield(self._connection.close(timeout=2)))
            except (
                asyncio.TimeoutError,
                asyncio.CancelledError,
                OSError,
                self.dbapi.asyncpg.PostgresError,
            ) as e:
                # in the case where we are recycling an old connection
                # that may have already been disconnected, close() will
                # fail with the above timeout. in this case, terminate
                # the connection without any further waiting.
                # see issue #8419
                self._connection.terminate()
                if isinstance(e, asyncio.CancelledError):
                    # re-raise CancelledError if we were cancelled
                    raise
        else:
            # not in a greenlet; this is the gc cleanup case
            self._connection.terminate()
        self._started = False

    @staticmethod
    def _default_name_func():
        # returning None leaves prepared statement naming to asyncpg
        return None
931
932
class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection):
    """Adapted connection that drives coroutines with ``await_fallback``.

    Chosen by ``AsyncAdapt_asyncpg_dbapi.connect`` when its
    ``async_fallback`` argument is true.
    """

    __slots__ = ()

    await_ = staticmethod(await_fallback)
937
938
class AsyncAdapt_asyncpg_dbapi:
    """pep-249 style module facade over the ``asyncpg`` package.

    Provides ``connect()``, the pep-249 exception hierarchy, type-object
    placeholders and ``Binary()``.  asyncpg exceptions raised through the
    adapted connection are translated into the exception classes defined
    here via ``_asyncpg_error_translate``.
    """

    def __init__(self, asyncpg):
        self.asyncpg = asyncpg
        self.paramstyle = "numeric_dollar"

    def connect(self, *arg, **kw):
        """Open an asyncpg connection and wrap it in an adapted connection.

        These SQLAlchemy-specific keyword arguments are removed from ``kw``
        before the rest are passed to the creator function:

        * ``async_fallback`` - when true (per ``util.asbool``) the result is
          an ``AsyncAdaptFallback_asyncpg_connection`` awaited with
          ``await_fallback``; otherwise an ``AsyncAdapt_asyncpg_connection``
          awaited with ``await_only``.
        * ``async_creator_fn`` - coroutine function used to connect;
          defaults to ``asyncpg.connect``.
        * ``prepared_statement_cache_size`` - defaults to 100.
        * ``prepared_statement_name_func`` - defaults to ``None``.
        """
        async_fallback = kw.pop("async_fallback", False)
        creator_fn = kw.pop("async_creator_fn", self.asyncpg.connect)
        prepared_statement_cache_size = kw.pop(
            "prepared_statement_cache_size", 100
        )
        prepared_statement_name_func = kw.pop(
            "prepared_statement_name_func", None
        )

        if util.asbool(async_fallback):
            connection_cls = AsyncAdaptFallback_asyncpg_connection
            await_fn = await_fallback
        else:
            connection_cls = AsyncAdapt_asyncpg_connection
            await_fn = await_only

        return connection_cls(
            self,
            await_fn(creator_fn(*arg, **kw)),
            prepared_statement_cache_size=prepared_statement_cache_size,
            prepared_statement_name_func=prepared_statement_name_func,
        )

    class Error(Exception):
        pass

    class Warning(Exception):  # noqa
        pass

    class InterfaceError(Error):
        pass

    class DatabaseError(Error):
        pass

    class InternalError(DatabaseError):
        pass

    class OperationalError(DatabaseError):
        pass

    class ProgrammingError(DatabaseError):
        pass

    class IntegrityError(DatabaseError):
        pass

    class DataError(DatabaseError):
        pass

    class NotSupportedError(DatabaseError):
        pass

    class InternalServerError(InternalError):
        pass

    class InvalidCachedStatementError(NotSupportedError):
        def __init__(self, message):
            super().__init__(
                message + " (SQLAlchemy asyncpg dialect will now invalidate "
                "all prepared caches in response to this exception)",
            )

    # pep-249 datatype placeholders. As of SQLAlchemy 2.0 these aren't
    # used, however the test suite looks for these in a few cases.
    STRING = util.symbol("STRING")
    NUMBER = util.symbol("NUMBER")
    DATETIME = util.symbol("DATETIME")

    @util.memoized_property
    def _asyncpg_error_translate(self):
        """Map asyncpg exception classes to this DBAPI's exception classes.

        Uses the asyncpg module passed to the constructor (as ``connect()``
        and the rest of this adapter do) instead of importing ``asyncpg``
        again, so the mapping always matches the module that produced the
        connections and their exceptions.
        """
        asyncpg_exc = self.asyncpg.exceptions

        return {
            asyncpg_exc.IntegrityConstraintViolationError: self.IntegrityError,  # noqa: E501
            asyncpg_exc.PostgresError: self.Error,
            asyncpg_exc.SyntaxOrAccessError: self.ProgrammingError,
            asyncpg_exc.InterfaceError: self.InterfaceError,
            asyncpg_exc.InvalidCachedStatementError: self.InvalidCachedStatementError,  # noqa: E501
            asyncpg_exc.InternalServerError: self.InternalServerError,
        }

    def Binary(self, value):
        return value
1030
1031
class PGDialect_asyncpg(PGDialect):
    """PostgreSQL dialect built on the asyncpg driver.

    Specializes :class:`.PGDialect` with asyncpg-aware execution context,
    SQL compiler, identifier preparer and type implementations (see
    ``colspecs``), and installs asyncpg-level type codecs on each new
    connection through the hook returned by :meth:`.on_connect`.
    """

    driver = "asyncpg"
    supports_statement_cache = True

    supports_server_side_cursors = True

    render_bind_cast = True
    has_terminate = True

    default_paramstyle = "numeric_dollar"
    supports_sane_multi_rowcount = False
    execution_ctx_cls = PGExecutionContext_asyncpg
    statement_compiler = PGCompiler_asyncpg
    preparer = PGIdentifierPreparer_asyncpg

    # Generic / PostgreSQL types mapped to their asyncpg-specific
    # implementations, layered on top of the base PostgreSQL mapping.
    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            sqltypes.String: AsyncpgString,
            sqltypes.ARRAY: AsyncpgARRAY,
            BIT: AsyncpgBit,
            CITEXT: CITEXT,
            REGCONFIG: AsyncpgREGCONFIG,
            sqltypes.Time: AsyncpgTime,
            sqltypes.Date: AsyncpgDate,
            sqltypes.DateTime: AsyncpgDateTime,
            sqltypes.Interval: AsyncPgInterval,
            INTERVAL: AsyncPgInterval,
            sqltypes.Boolean: AsyncpgBoolean,
            sqltypes.Integer: AsyncpgInteger,
            sqltypes.SmallInteger: AsyncpgSmallInteger,
            sqltypes.BigInteger: AsyncpgBigInteger,
            sqltypes.Numeric: AsyncpgNumeric,
            sqltypes.Float: AsyncpgFloat,
            sqltypes.JSON: AsyncpgJSON,
            sqltypes.LargeBinary: AsyncpgByteA,
            json.JSONB: AsyncpgJSONB,
            sqltypes.JSON.JSONPathType: AsyncpgJSONPathType,
            sqltypes.JSON.JSONIndexType: AsyncpgJSONIndexType,
            sqltypes.JSON.JSONIntIndexType: AsyncpgJSONIntIndexType,
            sqltypes.JSON.JSONStrIndexType: AsyncpgJSONStrIndexType,
            sqltypes.Enum: AsyncPgEnum,
            OID: AsyncpgOID,
            REGCLASS: AsyncpgREGCLASS,
            sqltypes.CHAR: AsyncpgCHAR,
            ranges.AbstractSingleRange: _AsyncpgRange,
            ranges.AbstractMultiRange: _AsyncpgMultiRange,
        },
    )
    is_async = True
    # Wall-clock time of the last call to _invalidate_schema_cache();
    # 0 means it has never been called on this dialect.
    _invalidate_schema_cache_asof = 0

    def _invalidate_schema_cache(self):
        """Record the current time as the schema-cache invalidation point.

        NOTE(review): readers of ``_invalidate_schema_cache_asof`` live
        outside this class; presumably connections compare it with their
        own cache timestamp to decide whether to discard cached state.
        """
        self._invalidate_schema_cache_asof = time.time()

    @util.memoized_property
    def _dbapi_version(self):
        """Return the DBAPI ``__version__`` as a tuple of integers.

        Each run of digits in the version string becomes one element,
        e.g. ``"0.29.0"`` -> ``(0, 29, 0)``.  When no DBAPI is loaded, or it
        has no ``__version__`` attribute, ``(99, 99, 99)`` is returned.
        The value is computed once and memoized.
        """
        if self.dbapi and hasattr(self.dbapi, "__version__"):
            return tuple(
                [
                    int(x)
                    for x in re.findall(
                        r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
                    )
                ]
            )
        else:
            return (99, 99, 99)

    @classmethod
    def import_dbapi(cls):
        """Import ``asyncpg`` and wrap it in the DBAPI-style adapter."""
        return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg"))

    @util.memoized_property
    def _isolation_lookup(self):
        """Map SQLAlchemy isolation level names to the identifiers passed to
        the adapted connection's ``set_isolation_level()``."""
        return {
            "AUTOCOMMIT": "autocommit",
            "READ COMMITTED": "read_committed",
            "REPEATABLE READ": "repeatable_read",
            "SERIALIZABLE": "serializable",
        }

    def get_isolation_level_values(self, dbapi_connection):
        """Return the isolation level names this dialect accepts."""
        return list(self._isolation_lookup)

    def set_isolation_level(self, dbapi_connection, level):
        """Apply ``level`` to the adapted connection.

        :raises KeyError: if ``level`` is not one of the names returned by
            :meth:`get_isolation_level_values`.
        """
        dbapi_connection.set_isolation_level(self._isolation_lookup[level])

    def detect_autocommit_setting(self, dbapi_conn) -> bool:
        """Report the adapted connection's ``autocommit`` flag as a bool."""
        return bool(dbapi_conn.autocommit)

    def set_readonly(self, connection, value):
        """Set the ``readonly`` flag on the adapted connection."""
        connection.readonly = value

    def get_readonly(self, connection):
        """Return the adapted connection's ``readonly`` flag."""
        return connection.readonly

    def set_deferrable(self, connection, value):
        """Set the ``deferrable`` flag on the adapted connection."""
        connection.deferrable = value

    def get_deferrable(self, connection):
        """Return the adapted connection's ``deferrable`` flag."""
        return connection.deferrable

    def do_terminate(self, dbapi_connection) -> None:
        """Terminate the connection via the adapter's ``terminate()``."""
        dbapi_connection.terminate()

    def create_connect_args(self, url):
        """Translate ``url`` into keyword arguments for ``asyncpg.connect``.

        Query-string options are merged into the keyword arguments.  For
        multihost URLs (``?host=H1:P1&host=H2:P2``):

        * a single host sets ``host`` and, if present, ``port``;
        * several hosts set ``host`` and ``port`` to parallel lists, and
          every entry must carry both a host and a port.

        ``port`` (single-host form) and ``prepared_statement_cache_size``
        are coerced to ``int`` when present.

        :returns: ``([], opts)`` -- no positional args, keyword ``opts``.
        :raises ArgumentError: if a multihost URL has an entry that is
            missing its host or its port.
        """
        opts = url.translate_connect_args(username="user")
        multihosts, multiports = self._split_multihost_from_url(url)

        opts.update(url.query)

        if multihosts:
            # internal invariant: hosts and ports are split in parallel
            assert multiports
            if len(multihosts) == 1:
                opts["host"] = multihosts[0]
                if multiports[0] is not None:
                    opts["port"] = multiports[0]
            elif not all(multihosts):
                raise exc.ArgumentError(
                    "All hosts are required to be present"
                    " for asyncpg multiple host URL"
                )
            elif not all(multiports):
                # asyncpg has no default port for multihost lists, so each
                # host must name its port explicitly
                raise exc.ArgumentError(
                    "All ports are required to be present"
                    " for asyncpg multiple host URL"
                )
            else:
                opts["host"] = list(multihosts)
                opts["port"] = list(multiports)
        else:
            util.coerce_kw_type(opts, "port", int)
        util.coerce_kw_type(opts, "prepared_statement_cache_size", int)
        return ([], opts)

    def do_ping(self, dbapi_connection):
        """Ping the connection; any failure propagates as an exception."""
        dbapi_connection.ping()
        return True

    @classmethod
    def get_pool_class(cls, url):
        """Choose the pool class based on the ``async_fallback`` URL flag."""
        async_fallback = url.query.get("async_fallback", False)

        if util.asbool(async_fallback):
            return pool.FallbackAsyncAdaptedQueuePool
        else:
            return pool.AsyncAdaptedQueuePool

    def is_disconnect(self, e, connection, cursor):
        """Return True if ``e`` indicates the connection has been lost.

        When a DBAPI connection is given, the underlying asyncpg
        connection's ``is_closed()`` decides.  Otherwise ``e`` must be a
        DBAPI ``InterfaceError`` whose message contains
        ``"connection is closed"``.
        """
        if connection:
            return connection._connection.is_closed()
        else:
            return isinstance(
                e, self.dbapi.InterfaceError
            ) and "connection is closed" in str(e)

    async def setup_asyncpg_json_codec(self, conn):
        """set up JSON codec for asyncpg.

        This occurs for all new connections and
        can be overridden by third party dialects.

        Values are encoded with ``str.encode`` and decoded with the
        dialect's ``json_deserializer``, falling back to ``json.loads``.

        .. versionadded:: 1.4.27

        """

        asyncpg_connection = conn._connection
        deserializer = self._json_deserializer or _py_json.loads

        def _json_decoder(bin_value):
            return deserializer(bin_value.decode())

        await asyncpg_connection.set_type_codec(
            "json",
            encoder=str.encode,
            decoder=_json_decoder,
            schema="pg_catalog",
            format="binary",
        )

    async def setup_asyncpg_jsonb_codec(self, conn):
        """set up JSONB codec for asyncpg.

        This occurs for all new connections and
        can be overridden by third party dialects.

        Like the JSON codec, but the binary ``jsonb`` wire format carries a
        leading ``\\x01`` byte, which is added on encode and stripped on
        decode.

        .. versionadded:: 1.4.27

        """

        asyncpg_connection = conn._connection
        deserializer = self._json_deserializer or _py_json.loads

        def _jsonb_encoder(str_value):
            # \x01 is the prefix for jsonb used by PostgreSQL.
            # asyncpg requires it when format='binary'
            return b"\x01" + str_value.encode()

        def _jsonb_decoder(bin_value):
            # the first byte is the \x01 prefix for jsonb used by PostgreSQL.
            # asyncpg returns it when format='binary'
            return deserializer(bin_value[1:].decode())

        await asyncpg_connection.set_type_codec(
            "jsonb",
            encoder=_jsonb_encoder,
            decoder=_jsonb_decoder,
            schema="pg_catalog",
            format="binary",
        )

    async def _disable_asyncpg_inet_codecs(self, conn):
        """Install identity text codecs for ``inet`` and ``cidr``.

        This replaces asyncpg's default codecs for those types, so values
        pass through as their text representation unchanged.
        """
        asyncpg_connection = conn._connection

        await asyncpg_connection.set_type_codec(
            "inet",
            encoder=lambda s: s,
            decoder=lambda s: s,
            schema="pg_catalog",
            format="text",
        )

        await asyncpg_connection.set_type_codec(
            "cidr",
            encoder=lambda s: s,
            decoder=lambda s: s,
            schema="pg_catalog",
            format="text",
        )

    def on_connect(self):
        """on_connect for asyncpg

        A major component of this for asyncpg is to set up type decoders at the
        asyncpg level.

        The returned hook installs the JSON and JSONB codecs, disables the
        INET/CIDR codecs when ``_native_inet_types`` is ``False``, and then
        runs the parent dialect's on-connect hook, if there is one.

        See https://github.com/MagicStack/asyncpg/issues/623 for
        notes on JSON/JSONB implementation.

        """

        super_connect = super().on_connect()

        def connect(conn):
            conn.await_(self.setup_asyncpg_json_codec(conn))
            conn.await_(self.setup_asyncpg_jsonb_codec(conn))

            if self._native_inet_types is False:
                conn.await_(self._disable_asyncpg_inet_codecs(conn))
            if super_connect is not None:
                super_connect(conn)

        return connect

    def get_driver_connection(self, connection):
        """Return the raw asyncpg connection wrapped by the adapter."""
        return connection._connection
1291
1292
# Module-level alias for the dialect class; presumably the name SQLAlchemy's
# dialect loader resolves for ``postgresql+asyncpg`` URLs -- confirm against
# the registry.
dialect = PGDialect_asyncpg