# dialects/postgresql/asyncpg.py
# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors <see AUTHORS
# file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors

9r"""
10.. dialect:: postgresql+asyncpg
11 :name: asyncpg
12 :dbapi: asyncpg
13 :connectstring: postgresql+asyncpg://user:password@host:port/dbname[?key=value&key=value...]
14 :url: https://magicstack.github.io/asyncpg/
15
16The asyncpg dialect is SQLAlchemy's first Python asyncio dialect.
17
18Using a special asyncio mediation layer, the asyncpg dialect is usable
19as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
20extension package.
21
22This dialect should normally be used only with the
23:func:`_asyncio.create_async_engine` engine creation function::
24
25 from sqlalchemy.ext.asyncio import create_async_engine
26
27 engine = create_async_engine(
28 "postgresql+asyncpg://user:pass@hostname/dbname"
29 )
30
31.. versionadded:: 1.4
32
33.. note::
34
    By default asyncpg does not decode the ``json`` and ``jsonb`` types and
    returns them as strings. SQLAlchemy sets a default type decoder for the
    ``json`` and ``jsonb`` types using the Python builtin ``json.loads``
    function. The json implementation used can be changed by setting the
    attribute ``json_deserializer`` when creating the engine with
    :func:`create_engine` or :func:`create_async_engine`.
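
    For example, a faster JSON implementation may be substituted; a minimal
    sketch, assuming the third party ``orjson`` package is installed (any
    callable that accepts a JSON string may be used here)::

        import orjson

        from sqlalchemy.ext.asyncio import create_async_engine

        engine = create_async_engine(
            "postgresql+asyncpg://user:pass@hostname/dbname",
            json_deserializer=orjson.loads,
        )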

.. _asyncpg_multihost:

Multihost Connections
--------------------------

The asyncpg dialect features support for multiple fallback hosts in the
same way as the psycopg2 and psycopg dialects, using the same syntax:
``host=<host>:<port>`` combinations are passed as additional query string
arguments. However, there is no default port, so all hosts must have a
complete port number present, otherwise an exception is raised::

    engine = create_async_engine(
        "postgresql+asyncpg://user:password@/dbname?host=HostA:5432&host=HostB:5432&host=HostC:5432"
    )

For complete background on this syntax, see :ref:`psycopg2_multi_host`.

.. versionadded:: 2.0.18

.. seealso::

    :ref:`psycopg2_multi_host`

.. _asyncpg_prepared_statement_cache:

Prepared Statement Cache
--------------------------

The asyncpg SQLAlchemy dialect makes use of ``asyncpg.connection.prepare()``
for all statements. The prepared statement objects are cached after
construction, which appears to grant a 10% or more performance improvement for
statement invocation. The cache is on a per-DBAPI connection basis, which
means that the primary storage for prepared statements is within DBAPI
connections pooled within the connection pool. The size of this cache
defaults to 100 statements per DBAPI connection and may be adjusted using the
``prepared_statement_cache_size`` DBAPI argument (note that while this argument
is implemented by SQLAlchemy, it is part of the DBAPI emulation portion of the
asyncpg dialect, and is therefore handled as a DBAPI argument, not a dialect
argument)::

    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500"
    )

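Because it is consumed as a DBAPI argument, the same setting may also be
passed programmatically via ``connect_args``; a sketch equivalent to the
URL form above::

    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname",
        connect_args={"prepared_statement_cache_size": 500},
    )
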
To disable the prepared statement cache, use a value of zero::

    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0"
    )

.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg.


.. warning:: The ``asyncpg`` database driver necessarily uses caches for
   PostgreSQL type OIDs, which become stale when custom PostgreSQL datatypes
   such as ``ENUM`` objects are changed via DDL operations. Additionally,
   prepared statements themselves which are optionally cached by SQLAlchemy's
   driver as described above may also become "stale" when DDL has been emitted
   to the PostgreSQL database which modifies the tables or other objects
   involved in a particular prepared statement.

   The SQLAlchemy asyncpg dialect will invalidate these caches within its local
   process when statements that represent DDL are emitted on a local
   connection, but this is only controllable within a single Python process /
   database engine.  If DDL changes are made from other database engines
   and/or processes, a running application may encounter asyncpg exceptions
   ``InvalidCachedStatementError`` and/or ``InternalServerError("cache lookup
   failed for type <oid>")`` if it refers to pooled database connections which
   operated upon the previous structures. The SQLAlchemy asyncpg dialect will
   recover from these error cases when the driver raises these exceptions by
   clearing its internal caches as well as those of the asyncpg driver in
   response to them, but cannot prevent them from being raised in the first
   place if the cached prepared statement or asyncpg type caches have gone
   stale, nor can it retry the statement as the PostgreSQL transaction is
   invalidated when these errors occur.
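
   Since the transaction itself is invalidated, recovery has to occur at the
   application level by retrying the whole transaction. A minimal sketch of
   such a retry; this helper is purely illustrative application code, not
   part of the dialect::

      from sqlalchemy import exc

      async def execute_with_retry(engine, stmt, retries=1):
          # the dialect will have cleared its caches by the time the
          # error propagates, so retrying in a new transaction may succeed
          for attempt in range(retries + 1):
              try:
                  async with engine.begin() as conn:
                      return await conn.execute(stmt)
              except exc.DBAPIError:
                  if attempt == retries:
                      raise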

.. _asyncpg_prepared_statement_name:

Prepared Statement Name with PgBouncer
--------------------------------------

By default, asyncpg enumerates prepared statements in numeric order, which
can lead to errors if a name has already been taken for another prepared
statement. This issue can arise if your application uses database proxies
such as PgBouncer to handle connections. One possible workaround is to
use dynamic prepared statement names, which asyncpg now supports through
an optional ``name`` value for the statement name. This allows you to
generate your own unique names that won't conflict with existing ones.
To achieve this, you can provide a function that will be called every time
a prepared statement is prepared::

    from uuid import uuid4

    from sqlalchemy.pool import NullPool

    engine = create_async_engine(
        "postgresql+asyncpg://user:pass@somepgbouncer/dbname",
        poolclass=NullPool,
        connect_args={
            "prepared_statement_name_func": lambda: f"__asyncpg_{uuid4()}__",
        },
    )

.. seealso::

    https://github.com/MagicStack/asyncpg/issues/837

    https://github.com/sqlalchemy/sqlalchemy/issues/6467

.. warning:: When using PgBouncer, to prevent a buildup of useless prepared statements in
   your application, it's important to use the :class:`.NullPool` pool
   class, and to configure PgBouncer to use `DISCARD <https://www.postgresql.org/docs/current/sql-discard.html>`_
   when returning connections. The DISCARD command is used to release resources held by the db connection,
   including prepared statements. Without proper setup, prepared statements can
   accumulate quickly and cause performance issues.
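
For example, PgBouncer might be configured along these lines; a sketch only,
as the exact settings depend on the deployment, and the PgBouncer
documentation should be treated as authoritative::

    [pgbouncer]
    pool_mode = transaction
    server_reset_query = DISCARD ALL
    server_reset_query_always = 1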

Disabling the PostgreSQL JIT to improve ENUM datatype handling
---------------------------------------------------------------

Asyncpg has an `issue <https://github.com/MagicStack/asyncpg/issues/727>`_ when
using PostgreSQL ENUM datatypes, where upon the creation of new database
connections, an expensive query may be emitted in order to retrieve metadata
regarding custom types, which has been shown to negatively affect performance.
To mitigate this issue, the PostgreSQL "jit" setting may be disabled from the
client using this setting passed to :func:`_asyncio.create_async_engine`::

    engine = create_async_engine(
        "postgresql+asyncpg://user:password@localhost/tmp",
        connect_args={"server_settings": {"jit": "off"}},
    )

.. seealso::

    https://github.com/MagicStack/asyncpg/issues/727

"""  # noqa

from __future__ import annotations

from collections import deque
import decimal
import json as _py_json
import re
import time

from . import json
from . import ranges
from .array import ARRAY as PGARRAY
from .base import _DECIMAL_TYPES
from .base import _FLOAT_TYPES
from .base import _INT_TYPES
from .base import ENUM
from .base import INTERVAL
from .base import OID
from .base import PGCompiler
from .base import PGDialect
from .base import PGExecutionContext
from .base import PGIdentifierPreparer
from .base import REGCLASS
from .base import REGCONFIG
from .types import BIT
from .types import BYTEA
from .types import CITEXT
from ... import exc
from ... import pool
from ... import util
from ...engine import AdaptedConnection
from ...engine import processors
from ...sql import sqltypes
from ...util.concurrency import asyncio
from ...util.concurrency import await_fallback
from ...util.concurrency import await_only


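# For the types below, render_bind_cast=True causes the compiler to render
# an explicit SQL cast for bound parameters (e.g. $1::VARCHAR), as asyncpg
# in many cases cannot otherwise infer the intended parameter type.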
class AsyncpgARRAY(PGARRAY):
    render_bind_cast = True


class AsyncpgString(sqltypes.String):
    render_bind_cast = True


class AsyncpgREGCONFIG(REGCONFIG):
    render_bind_cast = True


class AsyncpgTime(sqltypes.Time):
    render_bind_cast = True


class AsyncpgBit(BIT):
    render_bind_cast = True


class AsyncpgByteA(BYTEA):
    render_bind_cast = True


class AsyncpgDate(sqltypes.Date):
    render_bind_cast = True


class AsyncpgDateTime(sqltypes.DateTime):
    render_bind_cast = True


class AsyncpgBoolean(sqltypes.Boolean):
    render_bind_cast = True


class AsyncPgInterval(INTERVAL):
    render_bind_cast = True

    @classmethod
    def adapt_emulated_to_native(cls, interval, **kw):
        return AsyncPgInterval(precision=interval.second_precision)


class AsyncPgEnum(ENUM):
    render_bind_cast = True


class AsyncpgInteger(sqltypes.Integer):
    render_bind_cast = True


class AsyncpgSmallInteger(sqltypes.SmallInteger):
    render_bind_cast = True


class AsyncpgBigInteger(sqltypes.BigInteger):
    render_bind_cast = True


class AsyncpgJSON(json.JSON):
    def result_processor(self, dialect, coltype):
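        # asyncpg delivers JSON/JSONB values already deserialized via the
        # codecs established in PGDialect_asyncpg.on_connect(); returning
        # None means no further result processing is applied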
        return None


class AsyncpgJSONB(json.JSONB):
    def result_processor(self, dialect, coltype):
        return None


class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType):
    pass


class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
    __visit_name__ = "json_int_index"

    render_bind_cast = True


class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
    __visit_name__ = "json_str_index"

    render_bind_cast = True


class AsyncpgJSONPathType(json.JSONPathType):
    def bind_processor(self, dialect):
        def process(value):
            if isinstance(value, str):
                # If it's already a string assume that it's in json path
                # format. This allows using cast with json paths literals
                return value
            elif value:
                tokens = [str(elem) for elem in value]
                return tokens
            else:
                return []

        return process


class AsyncpgNumeric(sqltypes.Numeric):
    render_bind_cast = True

    def bind_processor(self, dialect):
        return None

    def result_processor(self, dialect, coltype):
        if self.asdecimal:
            if coltype in _FLOAT_TYPES:
                return processors.to_decimal_processor_factory(
                    decimal.Decimal, self._effective_decimal_return_scale
                )
            elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
                # asyncpg returns Decimal natively for NUMERIC (OID 1700)
                return None
            else:
                raise exc.InvalidRequestError(
                    "Unknown PG numeric type: %d" % coltype
                )
        else:
            if coltype in _FLOAT_TYPES:
                # asyncpg returns float natively for FLOAT8 (OID 701)
                return None
            elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
                return processors.to_float
            else:
                raise exc.InvalidRequestError(
                    "Unknown PG numeric type: %d" % coltype
                )


class AsyncpgFloat(AsyncpgNumeric, sqltypes.Float):
    __visit_name__ = "float"
    render_bind_cast = True


class AsyncpgREGCLASS(REGCLASS):
    render_bind_cast = True


class AsyncpgOID(OID):
    render_bind_cast = True


class AsyncpgCHAR(sqltypes.CHAR):
    render_bind_cast = True


class _AsyncpgRange(ranges.AbstractSingleRangeImpl):
    def bind_processor(self, dialect):
        asyncpg_Range = dialect.dbapi.asyncpg.Range

        def to_range(value):
            if isinstance(value, ranges.Range):
                value = asyncpg_Range(
                    value.lower,
                    value.upper,
                    lower_inc=value.bounds[0] == "[",
                    upper_inc=value.bounds[1] == "]",
                    empty=value.empty,
                )
            return value

        return to_range

    def result_processor(self, dialect, coltype):
        def to_range(value):
            if value is not None:
                empty = value.isempty
                value = ranges.Range(
                    value.lower,
                    value.upper,
                    bounds=f"{'[' if empty or value.lower_inc else '('}"  # type: ignore  # noqa: E501
                    f"{']' if not empty and value.upper_inc else ')'}",
                    empty=empty,
                )
            return value

        return to_range


class _AsyncpgMultiRange(ranges.AbstractMultiRangeImpl):
    def bind_processor(self, dialect):
        asyncpg_Range = dialect.dbapi.asyncpg.Range

        NoneType = type(None)

        def to_range(value):
            if isinstance(value, (str, NoneType)):
                return value

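            # note the nested function intentionally shadows the outer
            # to_range; it converts one Range element of the multirange,
            # while the enclosing function maps it across the sequence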
            def to_range(value):
                if isinstance(value, ranges.Range):
                    value = asyncpg_Range(
                        value.lower,
                        value.upper,
                        lower_inc=value.bounds[0] == "[",
                        upper_inc=value.bounds[1] == "]",
                        empty=value.empty,
                    )
                return value

            return [to_range(element) for element in value]

        return to_range

    def result_processor(self, dialect, coltype):
        def to_range_array(value):
            def to_range(rvalue):
                if rvalue is not None:
                    empty = rvalue.isempty
                    rvalue = ranges.Range(
                        rvalue.lower,
                        rvalue.upper,
                        bounds=f"{'[' if empty or rvalue.lower_inc else '('}"  # type: ignore  # noqa: E501
                        f"{']' if not empty and rvalue.upper_inc else ')'}",
                        empty=empty,
                    )
                return rvalue

            if value is not None:
                value = ranges.MultiRange(to_range(elem) for elem in value)

            return value

        return to_range_array


class PGExecutionContext_asyncpg(PGExecutionContext):
    def handle_dbapi_exception(self, e):
        if isinstance(
            e,
            (
                self.dialect.dbapi.InvalidCachedStatementError,
                self.dialect.dbapi.InternalServerError,
            ),
        ):
            self.dialect._invalidate_schema_cache()

    def pre_exec(self):
        if self.isddl:
            self.dialect._invalidate_schema_cache()

        self.cursor._invalidate_schema_cache_asof = (
            self.dialect._invalidate_schema_cache_asof
        )

        if not self.compiled:
            return

    def create_server_side_cursor(self):
        return self._dbapi_connection.cursor(server_side=True)


class PGCompiler_asyncpg(PGCompiler):
    pass


class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer):
    pass


class AsyncAdapt_asyncpg_cursor:
    __slots__ = (
        "_adapt_connection",
        "_connection",
        "_rows",
        "description",
        "arraysize",
        "rowcount",
        "_cursor",
        "_invalidate_schema_cache_asof",
    )

    server_side = False

    def __init__(self, adapt_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection
        self._rows = deque()
        self._cursor = None
        self.description = None
        self.arraysize = 1
        self.rowcount = -1
        self._invalidate_schema_cache_asof = 0

    def close(self):
        self._rows.clear()

    def _handle_exception(self, error):
        self._adapt_connection._handle_exception(error)

    async def _prepare_and_execute(self, operation, parameters):
        adapt_connection = self._adapt_connection

        async with adapt_connection._execute_mutex:
            if not adapt_connection._started:
                await adapt_connection._start_transaction()

            if parameters is None:
                parameters = ()

            try:
                prepared_stmt, attributes = await adapt_connection._prepare(
                    operation, self._invalidate_schema_cache_asof
                )

                if attributes:
                    self.description = [
                        (
                            attr.name,
                            attr.type.oid,
                            None,
                            None,
                            None,
                            None,
                            None,
                        )
                        for attr in attributes
                    ]
                else:
                    self.description = None

                if self.server_side:
                    self._cursor = await prepared_stmt.cursor(*parameters)
                    self.rowcount = -1
                else:
                    self._rows = deque(await prepared_stmt.fetch(*parameters))
                    status = prepared_stmt.get_statusmsg()

                    reg = re.match(
                        r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)",
                        status or "",
                    )
                    if reg:
                        self.rowcount = int(reg.group(1))
                    else:
                        self.rowcount = -1

            except Exception as error:
                self._handle_exception(error)

    async def _executemany(self, operation, seq_of_parameters):
        adapt_connection = self._adapt_connection

        self.description = None
        async with adapt_connection._execute_mutex:
            await adapt_connection._check_type_cache_invalidation(
                self._invalidate_schema_cache_asof
            )

            if not adapt_connection._started:
                await adapt_connection._start_transaction()

            try:
                return await self._connection.executemany(
                    operation, seq_of_parameters
                )
            except Exception as error:
                self._handle_exception(error)

    def execute(self, operation, parameters=None):
        self._adapt_connection.await_(
            self._prepare_and_execute(operation, parameters)
        )

    def executemany(self, operation, seq_of_parameters):
        return self._adapt_connection.await_(
            self._executemany(operation, seq_of_parameters)
        )

    def setinputsizes(self, *inputsizes):
        raise NotImplementedError()

    def __iter__(self):
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self):
        if self._rows:
            return self._rows.popleft()
        else:
            return None

    def fetchmany(self, size=None):
        if size is None:
            size = self.arraysize

        rr = self._rows
        return [rr.popleft() for _ in range(min(size, len(rr)))]

    def fetchall(self):
        retval = list(self._rows)
        self._rows.clear()
        return retval


class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
    server_side = True
    __slots__ = ("_rowbuffer",)

    def __init__(self, adapt_connection):
        super().__init__(adapt_connection)
        self._rowbuffer = deque()

    def close(self):
        self._cursor = None
        self._rowbuffer.clear()

    def _buffer_rows(self):
        assert self._cursor is not None
        new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
        self._rowbuffer.extend(new_rows)

    def __aiter__(self):
        return self

    async def __anext__(self):
        # __anext__ must return the next row rather than act as an async
        # generator; buffer more rows from the server side cursor when the
        # local buffer is exhausted
        if not self._rowbuffer:
            self._buffer_rows()
            if not self._rowbuffer:
                raise StopAsyncIteration()
        return self._rowbuffer.popleft()

    def fetchone(self):
        if not self._rowbuffer:
            self._buffer_rows()
            if not self._rowbuffer:
                return None
        return self._rowbuffer.popleft()

    def fetchmany(self, size=None):
        if size is None:
            return self.fetchall()

        if not self._rowbuffer:
            self._buffer_rows()

        assert self._cursor is not None
        rb = self._rowbuffer
        lb = len(rb)
        if size > lb:
            rb.extend(
                self._adapt_connection.await_(self._cursor.fetch(size - lb))
            )

        return [rb.popleft() for _ in range(min(size, len(rb)))]

    def fetchall(self):
        ret = list(self._rowbuffer)
        ret.extend(self._adapt_connection.await_(self._all()))
        self._rowbuffer.clear()
        return ret

    async def _all(self):
        rows = []

        # TODO: looks like we have to hand-roll some kind of batching here.
        # hardcoding for the moment but this should be improved.
        while True:
            batch = await self._cursor.fetch(1000)
            if batch:
                rows.extend(batch)
                continue
            else:
                break
        return rows

    def executemany(self, operation, seq_of_parameters):
        raise NotImplementedError(
            "server side cursor doesn't support executemany yet"
        )


class AsyncAdapt_asyncpg_connection(AdaptedConnection):
    __slots__ = (
        "dbapi",
        "isolation_level",
        "_isolation_setting",
        "readonly",
        "deferrable",
        "_transaction",
        "_started",
        "_prepared_statement_cache",
        "_prepared_statement_name_func",
        "_invalidate_schema_cache_asof",
        "_execute_mutex",
    )

    await_ = staticmethod(await_only)

    def __init__(
        self,
        dbapi,
        connection,
        prepared_statement_cache_size=100,
        prepared_statement_name_func=None,
    ):
        self.dbapi = dbapi
        self._connection = connection
        self.isolation_level = self._isolation_setting = None
        self.readonly = False
        self.deferrable = False
        self._transaction = None
        self._started = False
        self._invalidate_schema_cache_asof = time.time()
        self._execute_mutex = asyncio.Lock()

        if prepared_statement_cache_size:
            self._prepared_statement_cache = util.LRUCache(
                prepared_statement_cache_size
            )
        else:
            self._prepared_statement_cache = None

        if prepared_statement_name_func:
            self._prepared_statement_name_func = prepared_statement_name_func
        else:
            self._prepared_statement_name_func = self._default_name_func

    async def _check_type_cache_invalidation(self, invalidate_timestamp):
        if invalidate_timestamp > self._invalidate_schema_cache_asof:
            await self._connection.reload_schema_state()
            self._invalidate_schema_cache_asof = invalidate_timestamp

    async def _prepare(self, operation, invalidate_timestamp):
        await self._check_type_cache_invalidation(invalidate_timestamp)

        cache = self._prepared_statement_cache
        if cache is None:
            prepared_stmt = await self._connection.prepare(
                operation, name=self._prepared_statement_name_func()
            )
            attributes = prepared_stmt.get_attributes()
            return prepared_stmt, attributes

        # asyncpg uses a type cache for the "attributes" which seems to go
        # stale independently of the PreparedStatement itself, so place that
        # collection in the cache as well.
        if operation in cache:
            prepared_stmt, attributes, cached_timestamp = cache[operation]

            # preparedstatements themselves also go stale for certain DDL
            # changes such as size of a VARCHAR changing, so there is also
            # a cross-connection invalidation timestamp
            if cached_timestamp > invalidate_timestamp:
                return prepared_stmt, attributes

        prepared_stmt = await self._connection.prepare(
            operation, name=self._prepared_statement_name_func()
        )
        attributes = prepared_stmt.get_attributes()
        cache[operation] = (prepared_stmt, attributes, time.time())

        return prepared_stmt, attributes

    def _handle_exception(self, error):
        if self._connection.is_closed():
            self._transaction = None
            self._started = False

        if not isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
            exception_mapping = self.dbapi._asyncpg_error_translate

            for super_ in type(error).__mro__:
                if super_ in exception_mapping:
                    translated_error = exception_mapping[super_](
                        "%s: %s" % (type(error), error)
                    )
                    translated_error.pgcode = translated_error.sqlstate = (
                        getattr(error, "sqlstate", None)
                    )
                    raise translated_error from error
            else:
                raise error
        else:
            raise error

    @property
    def autocommit(self):
        return self.isolation_level == "autocommit"

    @autocommit.setter
    def autocommit(self, value):
        if value:
            self.isolation_level = "autocommit"
        else:
            self.isolation_level = self._isolation_setting

    def ping(self):
        try:
            _ = self.await_(self._async_ping())
        except Exception as error:
            self._handle_exception(error)

    async def _async_ping(self):
        if self._transaction is None and self.isolation_level != "autocommit":
            # create a transaction explicitly to support pgbouncer
            # transaction mode.  See #10226
            tr = self._connection.transaction()
            await tr.start()
            try:
                await self._connection.fetchrow(";")
            finally:
                await tr.rollback()
        else:
            await self._connection.fetchrow(";")

    def set_isolation_level(self, level):
        if self._started:
            self.rollback()
        self.isolation_level = self._isolation_setting = level

    async def _start_transaction(self):
        if self.isolation_level == "autocommit":
            return

        try:
            self._transaction = self._connection.transaction(
                isolation=self.isolation_level,
                readonly=self.readonly,
                deferrable=self.deferrable,
            )
            await self._transaction.start()
        except Exception as error:
            self._handle_exception(error)
        else:
            self._started = True

    def cursor(self, server_side=False):
        if server_side:
            return AsyncAdapt_asyncpg_ss_cursor(self)
        else:
            return AsyncAdapt_asyncpg_cursor(self)

    async def _rollback_and_discard(self):
        try:
            await self._transaction.rollback()
        finally:
            # if asyncpg .rollback() was actually called, then whether or
            # not it raised or succeeded, the transaction is done, discard it
            self._transaction = None
            self._started = False

    async def _commit_and_discard(self):
        try:
            await self._transaction.commit()
        finally:
            # if asyncpg .commit() was actually called, then whether or
            # not it raised or succeeded, the transaction is done, discard it
            self._transaction = None
            self._started = False

    def rollback(self):
        if self._started:
            try:
                self.await_(self._rollback_and_discard())
                self._transaction = None
                self._started = False
            except Exception as error:
                # don't dereference asyncpg transaction if we didn't
                # actually try to call rollback() on it
                self._handle_exception(error)

    def commit(self):
        if self._started:
            try:
                self.await_(self._commit_and_discard())
                self._transaction = None
                self._started = False
            except Exception as error:
                # don't dereference asyncpg transaction if we didn't
                # actually try to call commit() on it
                self._handle_exception(error)

    def close(self):
        self.rollback()

        self.await_(self._connection.close())

    def terminate(self):
        if util.concurrency.in_greenlet():
            # in a greenlet; this is the case where the connection
            # was invalidated.
            try:
                # try to gracefully close; see #10717
                # timeout added in asyncpg 0.14.0 December 2017
                self.await_(asyncio.shield(self._connection.close(timeout=2)))
            except (
                asyncio.TimeoutError,
                asyncio.CancelledError,
                OSError,
                self.dbapi.asyncpg.PostgresError,
            ):
                # in the case where we are recycling an old connection
                # that may have already been disconnected, close() will
                # fail with the above timeout.  in this case, terminate
                # the connection without any further waiting.
                # see issue #8419
                self._connection.terminate()
        else:
            # not in a greenlet; this is the gc cleanup case
            self._connection.terminate()
        self._started = False

    @staticmethod
    def _default_name_func():
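        # returning None lets asyncpg fall back to its default behavior
        # of generating numeric prepared statement names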
        return None


class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection):
    __slots__ = ()

    await_ = staticmethod(await_fallback)


class AsyncAdapt_asyncpg_dbapi:
    def __init__(self, asyncpg):
        self.asyncpg = asyncpg
        self.paramstyle = "numeric_dollar"

    def connect(self, *arg, **kw):
        async_fallback = kw.pop("async_fallback", False)
        creator_fn = kw.pop("async_creator_fn", self.asyncpg.connect)
        prepared_statement_cache_size = kw.pop(
            "prepared_statement_cache_size", 100
        )
        prepared_statement_name_func = kw.pop(
            "prepared_statement_name_func", None
        )

        if util.asbool(async_fallback):
            return AsyncAdaptFallback_asyncpg_connection(
                self,
                await_fallback(creator_fn(*arg, **kw)),
                prepared_statement_cache_size=prepared_statement_cache_size,
                prepared_statement_name_func=prepared_statement_name_func,
            )
        else:
            return AsyncAdapt_asyncpg_connection(
                self,
                await_only(creator_fn(*arg, **kw)),
                prepared_statement_cache_size=prepared_statement_cache_size,
                prepared_statement_name_func=prepared_statement_name_func,
            )

    class Error(Exception):
        pass

    class Warning(Exception):  # noqa
        pass

    class InterfaceError(Error):
        pass

    class DatabaseError(Error):
        pass

    class InternalError(DatabaseError):
        pass

    class OperationalError(DatabaseError):
        pass

    class ProgrammingError(DatabaseError):
        pass

    class IntegrityError(DatabaseError):
        pass

    class DataError(DatabaseError):
        pass

    class NotSupportedError(DatabaseError):
        pass

    class InternalServerError(InternalError):
        pass

    class InvalidCachedStatementError(NotSupportedError):
        def __init__(self, message):
            super().__init__(
                message + " (SQLAlchemy asyncpg dialect will now invalidate "
                "all prepared caches in response to this exception)",
            )

    # pep-249 datatype placeholders.  As of SQLAlchemy 2.0 these aren't
    # used, however the test suite looks for these in a few cases.
    STRING = util.symbol("STRING")
    NUMBER = util.symbol("NUMBER")
    DATETIME = util.symbol("DATETIME")

    @util.memoized_property
    def _asyncpg_error_translate(self):
        import asyncpg

        return {
            asyncpg.exceptions.IntegrityConstraintViolationError: self.IntegrityError,  # noqa: E501
            asyncpg.exceptions.PostgresError: self.Error,
            asyncpg.exceptions.SyntaxOrAccessError: self.ProgrammingError,
            asyncpg.exceptions.InterfaceError: self.InterfaceError,
            asyncpg.exceptions.InvalidCachedStatementError: self.InvalidCachedStatementError,  # noqa: E501
            asyncpg.exceptions.InternalServerError: self.InternalServerError,
        }

    def Binary(self, value):
        return value


class PGDialect_asyncpg(PGDialect):
    driver = "asyncpg"
    supports_statement_cache = True

    supports_server_side_cursors = True

    render_bind_cast = True
    has_terminate = True

    default_paramstyle = "numeric_dollar"
    supports_sane_multi_rowcount = False
    execution_ctx_cls = PGExecutionContext_asyncpg
    statement_compiler = PGCompiler_asyncpg
    preparer = PGIdentifierPreparer_asyncpg

    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            sqltypes.String: AsyncpgString,
            sqltypes.ARRAY: AsyncpgARRAY,
            BIT: AsyncpgBit,
            CITEXT: CITEXT,
            REGCONFIG: AsyncpgREGCONFIG,
            sqltypes.Time: AsyncpgTime,
            sqltypes.Date: AsyncpgDate,
            sqltypes.DateTime: AsyncpgDateTime,
            sqltypes.Interval: AsyncPgInterval,
            INTERVAL: AsyncPgInterval,
            sqltypes.Boolean: AsyncpgBoolean,
            sqltypes.Integer: AsyncpgInteger,
            sqltypes.SmallInteger: AsyncpgSmallInteger,
            sqltypes.BigInteger: AsyncpgBigInteger,
            sqltypes.Numeric: AsyncpgNumeric,
            sqltypes.Float: AsyncpgFloat,
            sqltypes.JSON: AsyncpgJSON,
            sqltypes.LargeBinary: AsyncpgByteA,
            json.JSONB: AsyncpgJSONB,
            sqltypes.JSON.JSONPathType: AsyncpgJSONPathType,
            sqltypes.JSON.JSONIndexType: AsyncpgJSONIndexType,
            sqltypes.JSON.JSONIntIndexType: AsyncpgJSONIntIndexType,
            sqltypes.JSON.JSONStrIndexType: AsyncpgJSONStrIndexType,
            sqltypes.Enum: AsyncPgEnum,
            OID: AsyncpgOID,
            REGCLASS: AsyncpgREGCLASS,
            sqltypes.CHAR: AsyncpgCHAR,
            ranges.AbstractSingleRange: _AsyncpgRange,
            ranges.AbstractMultiRange: _AsyncpgMultiRange,
        },
    )
    is_async = True
    _invalidate_schema_cache_asof = 0

    def _invalidate_schema_cache(self):
        self._invalidate_schema_cache_asof = time.time()

    @util.memoized_property
    def _dbapi_version(self):
        if self.dbapi and hasattr(self.dbapi, "__version__"):
            return tuple(
                [
                    int(x)
                    for x in re.findall(
                        r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
                    )
                ]
            )
        else:
            return (99, 99, 99)

    @classmethod
    def import_dbapi(cls):
        return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg"))

    @util.memoized_property
    def _isolation_lookup(self):
        return {
            "AUTOCOMMIT": "autocommit",
            "READ COMMITTED": "read_committed",
            "REPEATABLE READ": "repeatable_read",
            "SERIALIZABLE": "serializable",
        }

    def get_isolation_level_values(self, dbapi_connection):
        return list(self._isolation_lookup)

    def set_isolation_level(self, dbapi_connection, level):
        dbapi_connection.set_isolation_level(self._isolation_lookup[level])

    def set_readonly(self, connection, value):
        connection.readonly = value

    def get_readonly(self, connection):
        return connection.readonly

    def set_deferrable(self, connection, value):
        connection.deferrable = value

    def get_deferrable(self, connection):
        return connection.deferrable

    def do_terminate(self, dbapi_connection) -> None:
        dbapi_connection.terminate()

    def create_connect_args(self, url):
        opts = url.translate_connect_args(username="user")
        multihosts, multiports = self._split_multihost_from_url(url)

        opts.update(url.query)

        if multihosts:
            assert multiports
            if len(multihosts) == 1:
                opts["host"] = multihosts[0]
                if multiports[0] is not None:
                    opts["port"] = multiports[0]
            elif not all(multihosts):
                raise exc.ArgumentError(
                    "All hosts are required to be present"
                    " for asyncpg multiple host URL"
                )
            elif not all(multiports):
                raise exc.ArgumentError(
                    "All ports are required to be present"
                    " for asyncpg multiple host URL"
                )
            else:
                opts["host"] = list(multihosts)
                opts["port"] = list(multiports)
        else:
            util.coerce_kw_type(opts, "port", int)
        util.coerce_kw_type(opts, "prepared_statement_cache_size", int)
        return ([], opts)

    def do_ping(self, dbapi_connection):
        dbapi_connection.ping()
        return True

    @classmethod
    def get_pool_class(cls, url):
        async_fallback = url.query.get("async_fallback", False)

        if util.asbool(async_fallback):
            return pool.FallbackAsyncAdaptedQueuePool
        else:
            return pool.AsyncAdaptedQueuePool

    def is_disconnect(self, e, connection, cursor):
        if connection:
            return connection._connection.is_closed()
        else:
            return isinstance(
                e, self.dbapi.InterfaceError
            ) and "connection is closed" in str(e)

    async def setup_asyncpg_json_codec(self, conn):
        """set up JSON codec for asyncpg.

        This occurs for all new connections and
        can be overridden by third party dialects.

        .. versionadded:: 1.4.27

        """

        asyncpg_connection = conn._connection
        deserializer = self._json_deserializer or _py_json.loads

        def _json_decoder(bin_value):
            return deserializer(bin_value.decode())

        await asyncpg_connection.set_type_codec(
            "json",
            encoder=str.encode,
            decoder=_json_decoder,
            schema="pg_catalog",
            format="binary",
        )

    async def setup_asyncpg_jsonb_codec(self, conn):
        """set up JSONB codec for asyncpg.

        This occurs for all new connections and
        can be overridden by third party dialects.

        .. versionadded:: 1.4.27

        """

        asyncpg_connection = conn._connection
        deserializer = self._json_deserializer or _py_json.loads

        def _jsonb_encoder(str_value):
            # \x01 is the prefix for jsonb used by PostgreSQL.
            # asyncpg requires it when format='binary'
            return b"\x01" + str_value.encode()

        def _jsonb_decoder(bin_value):
            # the byte is the \x01 prefix for jsonb used by PostgreSQL.
            # asyncpg returns it when format='binary'
            return deserializer(bin_value[1:].decode())

        await asyncpg_connection.set_type_codec(
            "jsonb",
            encoder=_jsonb_encoder,
            decoder=_jsonb_decoder,
            schema="pg_catalog",
            format="binary",
        )

    async def _disable_asyncpg_inet_codecs(self, conn):
        asyncpg_connection = conn._connection

        await asyncpg_connection.set_type_codec(
            "inet",
            encoder=lambda s: s,
            decoder=lambda s: s,
            schema="pg_catalog",
            format="text",
        )

        await asyncpg_connection.set_type_codec(
            "cidr",
            encoder=lambda s: s,
            decoder=lambda s: s,
            schema="pg_catalog",
            format="text",
        )

    def on_connect(self):
        """on_connect for asyncpg

        A major component of this for asyncpg is to set up type decoders at the
        asyncpg level.

        See https://github.com/MagicStack/asyncpg/issues/623 for
        notes on JSON/JSONB implementation.

        """

        super_connect = super().on_connect()

        def connect(conn):
            conn.await_(self.setup_asyncpg_json_codec(conn))
            conn.await_(self.setup_asyncpg_jsonb_codec(conn))

            if self._native_inet_types is False:
                conn.await_(self._disable_asyncpg_inet_codecs(conn))
            if super_connect is not None:
                super_connect(conn)

        return connect

    def get_driver_connection(self, connection):
        return connection._connection


dialect = PGDialect_asyncpg