Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/sqlalchemy/dialects/postgresql/asyncpg.py: 42%
534 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1# postgresql/asyncpg.py
2# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors <see AUTHORS
3# file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7r"""
8.. dialect:: postgresql+asyncpg
9 :name: asyncpg
10 :dbapi: asyncpg
11 :connectstring: postgresql+asyncpg://user:password@host:port/dbname[?key=value&key=value...]
12 :url: https://magicstack.github.io/asyncpg/
14The asyncpg dialect is SQLAlchemy's first Python asyncio dialect.
16Using a special asyncio mediation layer, the asyncpg dialect is usable
17as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
18extension package.
20This dialect should normally be used only with the
21:func:`_asyncio.create_async_engine` engine creation function::
23 from sqlalchemy.ext.asyncio import create_async_engine
24 engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname")
26The dialect can also be run as a "synchronous" dialect within the
27:func:`_sa.create_engine` function, which will pass "await" calls into
28an ad-hoc event loop. This mode of operation is of **limited use**
29and is for special testing scenarios only. The mode can be enabled by
30adding the SQLAlchemy-specific flag ``async_fallback`` to the URL
31in conjunction with :func:`_sa.create_engine`::
33 # for testing purposes only; do not use in production!
34 engine = create_engine("postgresql+asyncpg://user:pass@hostname/dbname?async_fallback=true")
37.. versionadded:: 1.4
39.. note::
    By default asyncpg does not decode the ``json`` and ``jsonb`` types and
    returns them as strings. SQLAlchemy sets a default type decoder for
    ``json`` and ``jsonb`` types using the Python standard library
    ``json.loads`` function. The JSON implementation used can be changed by
    setting the ``json_deserializer`` parameter when creating the engine with
    :func:`create_engine` or :func:`create_async_engine`.
49.. _asyncpg_prepared_statement_cache:
51Prepared Statement Cache
52--------------------------
54The asyncpg SQLAlchemy dialect makes use of ``asyncpg.connection.prepare()``
55for all statements. The prepared statement objects are cached after
56construction which appears to grant a 10% or more performance improvement for
57statement invocation. The cache is on a per-DBAPI connection basis, which
58means that the primary storage for prepared statements is within DBAPI
59connections pooled within the connection pool. The size of this cache
60defaults to 100 statements per DBAPI connection and may be adjusted using the
61``prepared_statement_cache_size`` DBAPI argument (note that while this argument
62is implemented by SQLAlchemy, it is part of the DBAPI emulation portion of the
63asyncpg dialect, therefore is handled as a DBAPI argument, not a dialect
64argument)::
67 engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=500")
69To disable the prepared statement cache, use a value of zero::
71 engine = create_async_engine("postgresql+asyncpg://user:pass@hostname/dbname?prepared_statement_cache_size=0")
73.. versionadded:: 1.4.0b2 Added ``prepared_statement_cache_size`` for asyncpg.
76.. warning:: The ``asyncpg`` database driver necessarily uses caches for
77 PostgreSQL type OIDs, which become stale when custom PostgreSQL datatypes
78 such as ``ENUM`` objects are changed via DDL operations. Additionally,
79 prepared statements themselves which are optionally cached by SQLAlchemy's
80 driver as described above may also become "stale" when DDL has been emitted
81 to the PostgreSQL database which modifies the tables or other objects
82 involved in a particular prepared statement.
84 The SQLAlchemy asyncpg dialect will invalidate these caches within its local
85 process when statements that represent DDL are emitted on a local
86 connection, but this is only controllable within a single Python process /
87 database engine. If DDL changes are made from other database engines
88 and/or processes, a running application may encounter asyncpg exceptions
89 ``InvalidCachedStatementError`` and/or ``InternalServerError("cache lookup
90 failed for type <oid>")`` if it refers to pooled database connections which
91 operated upon the previous structures. The SQLAlchemy asyncpg dialect will
92 recover from these error cases when the driver raises these exceptions by
93 clearing its internal caches as well as those of the asyncpg driver in
94 response to them, but cannot prevent them from being raised in the first
95 place if the cached prepared statement or asyncpg type caches have gone
96 stale, nor can it retry the statement as the PostgreSQL transaction is
97 invalidated when these errors occur.
99Disabling the PostgreSQL JIT to improve ENUM datatype handling
100---------------------------------------------------------------
Asyncpg has an `issue <https://github.com/MagicStack/asyncpg/issues/727>`_ when
using PostgreSQL ENUM datatypes: upon the creation of new database
connections, an expensive query may be emitted to retrieve metadata regarding
custom types, which has been shown to negatively affect performance.
To mitigate this issue, the PostgreSQL "jit" setting may be disabled from the
client by passing the following ``connect_args`` to
:func:`_asyncio.create_async_engine`::
109 engine = create_async_engine(
110 "postgresql+asyncpg://user:password@localhost/tmp",
111 connect_args={"server_settings": {"jit": "off"}},
112 )
114.. seealso::
116 https://github.com/MagicStack/asyncpg/issues/727
118""" # noqa
120import collections
121import decimal
122import json as _py_json
123import re
124import time
126from . import json
127from .base import _DECIMAL_TYPES
128from .base import _FLOAT_TYPES
129from .base import _INT_TYPES
130from .base import ENUM
131from .base import INTERVAL
132from .base import OID
133from .base import PGCompiler
134from .base import PGDialect
135from .base import PGExecutionContext
136from .base import PGIdentifierPreparer
137from .base import REGCLASS
138from .base import UUID
139from ... import exc
140from ... import pool
141from ... import processors
142from ... import util
143from ...engine import AdaptedConnection
144from ...sql import sqltypes
145from ...util.concurrency import asyncio
146from ...util.concurrency import await_fallback
147from ...util.concurrency import await_only
150try:
151 from uuid import UUID as _python_UUID # noqa
152except ImportError:
153 _python_UUID = None
class AsyncpgTime(sqltypes.Time):
    """TIME type reporting the asyncpg DBAPI type symbol.

    Timezone-aware instances map to ``TIME_W_TZ``, naive ones to ``TIME``.
    """

    def get_dbapi_type(self, dbapi):
        return dbapi.TIME_W_TZ if self.timezone else dbapi.TIME
class AsyncpgDate(sqltypes.Date):
    """DATE type reporting the asyncpg ``DATE`` DBAPI type symbol."""

    def get_dbapi_type(self, dbapi):
        date_symbol = dbapi.DATE
        return date_symbol
class AsyncpgDateTime(sqltypes.DateTime):
    """TIMESTAMP type reporting the asyncpg DBAPI type symbol.

    Timezone-aware instances map to ``TIMESTAMP_W_TZ``, naive ones to
    ``TIMESTAMP``.
    """

    def get_dbapi_type(self, dbapi):
        return dbapi.TIMESTAMP_W_TZ if self.timezone else dbapi.TIMESTAMP
class AsyncpgBoolean(sqltypes.Boolean):
    """BOOLEAN type reporting the asyncpg ``BOOLEAN`` DBAPI type symbol."""

    def get_dbapi_type(self, dbapi):
        boolean_symbol = dbapi.BOOLEAN
        return boolean_symbol
class AsyncPgInterval(INTERVAL):
    """PostgreSQL INTERVAL type for the asyncpg dialect."""

    def get_dbapi_type(self, dbapi):
        return dbapi.INTERVAL

    @classmethod
    def adapt_emulated_to_native(cls, interval, **kw):
        """Build an asyncpg INTERVAL from a generic ``Interval`` type.

        Only the ``second_precision`` of the source type is carried over
        (as ``precision``).
        """
        precision = interval.second_precision
        return AsyncPgInterval(precision=precision)
class AsyncPgEnum(ENUM):
    """PostgreSQL ENUM type reporting the asyncpg ``ENUM`` type symbol."""

    def get_dbapi_type(self, dbapi):
        enum_symbol = dbapi.ENUM
        return enum_symbol
class AsyncpgInteger(sqltypes.Integer):
    """INTEGER type reporting the asyncpg ``INTEGER`` DBAPI type symbol."""

    def get_dbapi_type(self, dbapi):
        integer_symbol = dbapi.INTEGER
        return integer_symbol
class AsyncpgBigInteger(sqltypes.BigInteger):
    """BIGINT type reporting the asyncpg ``BIGINTEGER`` DBAPI type symbol."""

    def get_dbapi_type(self, dbapi):
        bigint_symbol = dbapi.BIGINTEGER
        return bigint_symbol
class AsyncpgJSON(json.JSON):
    """JSON type for asyncpg.

    No SQLAlchemy-level result processor is used; decoding is left to the
    ``json`` codec that the dialect installs on each asyncpg connection
    (see ``PGDialect_asyncpg.setup_asyncpg_json_codec``).
    """

    def get_dbapi_type(self, dbapi):
        json_symbol = dbapi.JSON
        return json_symbol

    def result_processor(self, dialect, coltype):
        # values arrive already decoded by the connection-level codec
        return None
class AsyncpgJSONB(json.JSONB):
    """JSONB type for asyncpg.

    No SQLAlchemy-level result processor is used; decoding is left to the
    ``jsonb`` codec that the dialect installs on each asyncpg connection
    (see ``PGDialect_asyncpg.setup_asyncpg_jsonb_codec``).
    """

    def get_dbapi_type(self, dbapi):
        jsonb_symbol = dbapi.JSONB
        return jsonb_symbol

    def result_processor(self, dialect, coltype):
        # values arrive already decoded by the connection-level codec
        return None
class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType):
    """Generic JSON index type; expected never to be asked for a DBAPI type.

    Concrete index values are expected to use the int/str subclasses below,
    so reaching this method indicates a logic error.
    """

    def get_dbapi_type(self, dbapi):
        message = "should not be here"
        raise NotImplementedError(message)
class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
    """Integer JSON index, bound with the asyncpg ``INTEGER`` symbol."""

    def get_dbapi_type(self, dbapi):
        integer_symbol = dbapi.INTEGER
        return integer_symbol
class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
    """String JSON index, bound with the asyncpg ``STRING`` symbol."""

    def get_dbapi_type(self, dbapi):
        string_symbol = dbapi.STRING
        return string_symbol
class AsyncpgJSONPathType(json.JSONPathType):
    """JSON path type whose bound value is sent as a list of strings."""

    def bind_processor(self, dialect):
        def process(value):
            # the path must be a sequence; every element is stringified
            assert isinstance(value, util.collections_abc.Sequence)
            return list(map(util.text_type, value))

        return process
class AsyncpgUUID(UUID):
    """UUID type for asyncpg.

    When the column is configured with ``as_uuid=False`` and the dialect
    uses native UUIDs, bound strings are converted to ``uuid.UUID`` objects
    and result values are converted back to ``str``. Otherwise no processing
    is applied (both processors return ``None``).
    """

    def get_dbapi_type(self, dbapi):
        return dbapi.UUID

    def bind_processor(self, dialect):
        if self.as_uuid or not dialect.use_native_uuid:
            return None

        def process(value):
            return None if value is None else _python_UUID(value)

        return process

    def result_processor(self, dialect, coltype):
        if self.as_uuid or not dialect.use_native_uuid:
            return None

        def process(value):
            return None if value is None else str(value)

        return process
class AsyncpgNumeric(sqltypes.Numeric):
    """NUMERIC type for asyncpg.

    No bind processing is applied. On the result side, a conversion is
    installed only when the column's type OID and the ``asdecimal`` setting
    disagree: float OIDs read with ``asdecimal=True`` are converted to
    ``Decimal``, and decimal/integer OIDs read with ``asdecimal=False`` are
    converted to ``float``.

    Raises:
        exc.InvalidRequestError: from :meth:`result_processor` when
            ``coltype`` is none of the known float, decimal or integer OIDs.
    """

    def get_dbapi_type(self, dbapi):
        return dbapi.NUMBER

    def bind_processor(self, dialect):
        return None

    def result_processor(self, dialect, coltype):
        is_float = coltype in _FLOAT_TYPES
        is_exact = coltype in _DECIMAL_TYPES or coltype in _INT_TYPES

        if not (is_float or is_exact):
            raise exc.InvalidRequestError(
                "Unknown PG numeric type: %d" % coltype
            )

        if self.asdecimal:
            if is_float:
                return processors.to_decimal_processor_factory(
                    decimal.Decimal, self._effective_decimal_return_scale
                )
            # decimal/integer OID: no conversion; presumably asyncpg
            # already returns Decimal / int for these -- confirm
            return None

        if is_float:
            # float OID: no conversion; presumably asyncpg already returns
            # float for these -- confirm
            return None
        return processors.to_float
class AsyncpgFloat(AsyncpgNumeric):
    """FLOAT type: numeric result handling, reported as ``FLOAT``."""

    def get_dbapi_type(self, dbapi):
        float_symbol = dbapi.FLOAT
        return float_symbol
class AsyncpgREGCLASS(REGCLASS):
    """REGCLASS type, bound with the asyncpg ``STRING`` symbol."""

    def get_dbapi_type(self, dbapi):
        string_symbol = dbapi.STRING
        return string_symbol
class AsyncpgOID(OID):
    """OID type, bound with the asyncpg ``INTEGER`` symbol."""

    def get_dbapi_type(self, dbapi):
        integer_symbol = dbapi.INTEGER
        return integer_symbol
class PGExecutionContext_asyncpg(PGExecutionContext):
    """Execution context for the asyncpg dialect.

    Propagates the dialect's schema-cache invalidation timestamp to the
    cursor before execution, and bumps it after DDL or after errors that
    indicate stale cached statements or types.
    """

    def handle_dbapi_exception(self, e):
        dbapi = self.dialect.dbapi
        stale_cache_errors = (
            dbapi.InvalidCachedStatementError,
            dbapi.InternalServerError,
        )
        if isinstance(e, stale_cache_errors):
            self.dialect._invalidate_schema_cache()

    def pre_exec(self):
        if self.isddl:
            self.dialect._invalidate_schema_cache()

        # the cursor hands this timestamp to the connection, which compares
        # it against its own cache timestamps
        self.cursor._invalidate_schema_cache_asof = (
            self.dialect._invalidate_schema_cache_asof
        )

        if self.compiled:
            # "enum" is not a type that can be cast to -- the cast would have
            # to name the specific enum type -- so ENUM parameters are left
            # out of setinputsizes() casting
            self.exclude_set_input_sizes = {AsyncAdapt_asyncpg_dbapi.ENUM}

    def create_server_side_cursor(self):
        return self._dbapi_connection.cursor(server_side=True)
class PGCompiler_asyncpg(PGCompiler):
    """SQL compiler for asyncpg; no changes from the base PostgreSQL one."""
class PGIdentifierPreparer_asyncpg(PGIdentifierPreparer):
    """Identifier preparer for asyncpg; identical to the PostgreSQL base."""
class AsyncAdapt_asyncpg_cursor:
    """DBAPI cursor emulation on top of an asyncpg connection.

    Statements are prepared through the owning adapted connection. For this
    (non-server-side) cursor the full result is fetched up front into
    ``_rows``, and the ``fetch*()`` methods and iteration then consume rows
    from the front of that buffer.
    """

    __slots__ = (
        "_adapt_connection",
        "_connection",
        "_rows",
        "description",
        "arraysize",
        "rowcount",
        "_inputsizes",
        "_cursor",
        "_invalidate_schema_cache_asof",
    )

    server_side = False

    def __init__(self, adapt_connection):
        self._adapt_connection = adapt_connection
        self._connection = adapt_connection._connection
        # a deque rather than a list: rows are consumed from the front, and
        # deque.popleft() is O(1) whereas list.pop(0) is O(n)
        self._rows = collections.deque()
        self._cursor = None
        self.description = None
        self.arraysize = 1
        self.rowcount = -1
        self._inputsizes = None
        self._invalidate_schema_cache_asof = 0

    def close(self):
        self._rows.clear()

    def _handle_exception(self, error):
        self._adapt_connection._handle_exception(error)

    def _parameter_placeholders(self, params):
        """Return a tuple of 1-based ``$n`` placeholders.

        Without input sizes there is one placeholder per parameter. With
        input sizes set (see :meth:`setinputsizes`), there is one per input
        size, and a placeholder whose type maps to a PostgreSQL name in
        ``_pg_types`` is rendered with an explicit ``::type`` cast.
        """
        if not self._inputsizes:
            return tuple("$%d" % idx for idx, _ in enumerate(params, 1))
        else:
            return tuple(
                "$%d::%s" % (idx, typ) if typ else "$%d" % idx
                for idx, typ in enumerate(
                    (_pg_types.get(typ) for typ in self._inputsizes), 1
                )
            )

    async def _prepare_and_execute(self, operation, parameters):
        """Prepare and run ``operation``, filling description and rows.

        Runs under the connection's execute mutex and starts a transaction
        first if none is active. Errors are routed to the connection's
        exception translation.
        """
        adapt_connection = self._adapt_connection

        async with adapt_connection._execute_mutex:

            if not adapt_connection._started:
                await adapt_connection._start_transaction()

            if parameters is not None:
                operation = operation % self._parameter_placeholders(
                    parameters
                )
            else:
                parameters = ()

            try:
                prepared_stmt, attributes = await adapt_connection._prepare(
                    operation, self._invalidate_schema_cache_asof
                )

                if attributes:
                    # 7-item pep-249 description; only name and type code
                    # are populated
                    self.description = [
                        (
                            attr.name,
                            attr.type.oid,
                            None,
                            None,
                            None,
                            None,
                            None,
                        )
                        for attr in attributes
                    ]
                else:
                    self.description = None

                if self.server_side:
                    self._cursor = await prepared_stmt.cursor(*parameters)
                    self.rowcount = -1
                else:
                    self._rows = collections.deque(
                        await prepared_stmt.fetch(*parameters)
                    )
                    status = prepared_stmt.get_statusmsg()

                    # derive rowcount from the command status tag
                    reg = re.match(
                        r"(?:SELECT|UPDATE|DELETE|INSERT \d+) (\d+)", status
                    )
                    if reg:
                        self.rowcount = int(reg.group(1))
                    else:
                        self.rowcount = -1

            except Exception as error:
                self._handle_exception(error)

    async def _executemany(self, operation, seq_of_parameters):
        """Run ``operation`` once per parameter set via asyncpg executemany.

        Placeholders are rendered from the first parameter set.
        """
        adapt_connection = self._adapt_connection

        async with adapt_connection._execute_mutex:
            await adapt_connection._check_type_cache_invalidation(
                self._invalidate_schema_cache_asof
            )

            if not adapt_connection._started:
                await adapt_connection._start_transaction()

            operation = operation % self._parameter_placeholders(
                seq_of_parameters[0]
            )

            try:
                return await self._connection.executemany(
                    operation, seq_of_parameters
                )
            except Exception as error:
                self._handle_exception(error)

    def execute(self, operation, parameters=None):
        self._adapt_connection.await_(
            self._prepare_and_execute(operation, parameters)
        )

    def executemany(self, operation, seq_of_parameters):
        return self._adapt_connection.await_(
            self._executemany(operation, seq_of_parameters)
        )

    def setinputsizes(self, *inputsizes):
        self._inputsizes = inputsizes

    def __iter__(self):
        while self._rows:
            yield self._rows.popleft()

    def fetchone(self):
        """Return the next buffered row, or ``None`` when exhausted."""
        if self._rows:
            return self._rows.popleft()
        else:
            return None

    def fetchmany(self, size=None):
        """Return up to ``size`` rows (default ``arraysize``) as a list."""
        if size is None:
            size = self.arraysize

        rows = self._rows
        return [rows.popleft() for _ in range(min(size, len(rows)))]

    def fetchall(self):
        """Return all remaining buffered rows as a list and empty the buffer."""
        retval = list(self._rows)
        self._rows.clear()
        return retval
class AsyncAdapt_asyncpg_ss_cursor(AsyncAdapt_asyncpg_cursor):
    """Server-side cursor that streams rows from an asyncpg cursor object.

    Instead of buffering the whole result, rows are pulled from the asyncpg
    cursor (``self._cursor``) in batches and held in the ``_rowbuffer``
    deque, which the fetch methods serve from.
    """

    server_side = True
    __slots__ = ("_rowbuffer",)

    def __init__(self, adapt_connection):
        super(AsyncAdapt_asyncpg_ss_cursor, self).__init__(adapt_connection)
        self._rowbuffer = None

    def close(self):
        self._cursor = None
        self._rowbuffer = None

    def _buffer_rows(self):
        """Replace ``_rowbuffer`` with the next batch of up to 50 rows."""
        new_rows = self._adapt_connection.await_(self._cursor.fetch(50))
        self._rowbuffer = collections.deque(new_rows)

    def __aiter__(self):
        return self

    async def __anext__(self):
        """Return the next row; raise StopAsyncIteration when exhausted."""
        # Must be a plain coroutine with no ``yield``: ``async for`` awaits
        # the object returned by __anext__(), and an async generator object
        # is not awaitable.
        if not self._rowbuffer:
            # we are already inside a coroutine here, so await the fetch
            # directly instead of going through self._adapt_connection.await_
            self._rowbuffer = collections.deque(await self._cursor.fetch(50))
            if not self._rowbuffer:
                raise StopAsyncIteration()
        return self._rowbuffer.popleft()

    def fetchone(self):
        if not self._rowbuffer:
            self._buffer_rows()
            if not self._rowbuffer:
                return None
        return self._rowbuffer.popleft()

    def fetchmany(self, size=None):
        """Return up to ``size`` rows; with no size, return all remaining."""
        if size is None:
            return self.fetchall()

        if not self._rowbuffer:
            self._buffer_rows()

        buf = list(self._rowbuffer)
        lb = len(buf)
        if size > lb:
            # top up from the server cursor for the shortfall
            buf.extend(
                self._adapt_connection.await_(self._cursor.fetch(size - lb))
            )

        result = buf[0:size]
        self._rowbuffer = collections.deque(buf[size:])
        return result

    def fetchall(self):
        """Return the buffered rows plus every row left on the server cursor."""
        # _rowbuffer is still None if nothing has been buffered yet, so that
        # state is treated as an empty buffer
        buffered = list(self._rowbuffer) if self._rowbuffer else []
        remaining = self._adapt_connection.await_(self._all())
        self._rowbuffer = collections.deque()
        return buffered + list(remaining)

    async def _all(self):
        """Drain the server cursor in batches of 1000 rows."""
        rows = []

        # TODO: looks like we have to hand-roll some kind of batching here.
        # hardcoding for the moment but this should be improved.
        while True:
            batch = await self._cursor.fetch(1000)
            if not batch:
                break
            rows.extend(batch)
        return rows

    def executemany(self, operation, seq_of_parameters):
        raise NotImplementedError(
            "server side cursor doesn't support executemany yet"
        )
class AsyncAdapt_asyncpg_connection(AdaptedConnection):
    """pep-249 style connection facade over an asyncpg connection.

    Manages an explicit asyncpg transaction (started lazily by the cursor),
    an optional LRU cache of prepared statements, and translation of asyncpg
    exceptions into the emulated DBAPI exception classes.
    """

    __slots__ = (
        "dbapi",
        "_connection",
        "isolation_level",
        "_isolation_setting",
        "readonly",
        "deferrable",
        "_transaction",
        "_started",
        "_prepared_statement_cache",
        "_invalidate_schema_cache_asof",
        "_execute_mutex",
    )

    await_ = staticmethod(await_only)

    def __init__(self, dbapi, connection, prepared_statement_cache_size=100):
        self.dbapi = dbapi
        self._connection = connection
        self.isolation_level = self._isolation_setting = "read_committed"
        self.readonly = False
        self.deferrable = False
        self._transaction = None
        self._started = False
        self._invalidate_schema_cache_asof = time.time()
        self._execute_mutex = asyncio.Lock()

        # a falsy size (e.g. 0) disables statement caching entirely
        self._prepared_statement_cache = (
            util.LRUCache(prepared_statement_cache_size)
            if prepared_statement_cache_size
            else None
        )

    async def _check_type_cache_invalidation(self, invalidate_timestamp):
        """Reload asyncpg's schema state if ``invalidate_timestamp`` is newer."""
        if invalidate_timestamp <= self._invalidate_schema_cache_asof:
            return
        await self._connection.reload_schema_state()
        self._invalidate_schema_cache_asof = invalidate_timestamp

    async def _prepare(self, operation, invalidate_timestamp):
        """Return ``(prepared_statement, attributes)`` for ``operation``.

        A cached entry is reused only if it was prepared after
        ``invalidate_timestamp``. The attributes are cached alongside the
        statement because asyncpg's type cache for them can go stale
        independently of the statement itself.
        """
        await self._check_type_cache_invalidation(invalidate_timestamp)

        cache = self._prepared_statement_cache
        if cache is not None and operation in cache:
            stmt, attrs, prepared_at = cache[operation]
            # DDL (e.g. a VARCHAR length change) can also stale the
            # statement, hence the cross-connection timestamp check
            if prepared_at > invalidate_timestamp:
                return stmt, attrs

        stmt = await self._connection.prepare(operation)
        attrs = stmt.get_attributes()
        if cache is not None:
            cache[operation] = (stmt, attrs, time.time())
        return stmt, attrs

    def _handle_exception(self, error):
        """Re-raise ``error``, translated into a DBAPI exception if possible.

        Transaction state is reset first if the underlying connection has
        closed. Errors that are already emulated DBAPI errors, or that have
        no mapping, are re-raised unchanged.
        """
        if self._connection.is_closed():
            self._transaction = None
            self._started = False

        if isinstance(error, AsyncAdapt_asyncpg_dbapi.Error):
            raise error

        exception_mapping = self.dbapi._asyncpg_error_translate
        translated_cls = next(
            (
                exception_mapping[klass]
                for klass in type(error).__mro__
                if klass in exception_mapping
            ),
            None,
        )
        if translated_cls is None:
            raise error

        translated_error = translated_cls("%s: %s" % (type(error), error))
        translated_error.pgcode = translated_error.sqlstate = getattr(
            error, "sqlstate", None
        )
        raise translated_error from error

    @property
    def autocommit(self):
        return self.isolation_level == "autocommit"

    @autocommit.setter
    def autocommit(self, value):
        self.isolation_level = (
            "autocommit" if value else self._isolation_setting
        )

    def set_isolation_level(self, level):
        # a running transaction is rolled back before the level changes
        if self._started:
            self.rollback()
        self.isolation_level = self._isolation_setting = level

    async def _start_transaction(self):
        """Begin an asyncpg transaction unless in autocommit mode."""
        if self.isolation_level == "autocommit":
            return

        try:
            self._transaction = self._connection.transaction(
                isolation=self.isolation_level,
                readonly=self.readonly,
                deferrable=self.deferrable,
            )
            await self._transaction.start()
        except Exception as error:
            self._handle_exception(error)
        else:
            self._started = True

    def cursor(self, server_side=False):
        cursor_cls = (
            AsyncAdapt_asyncpg_ss_cursor
            if server_side
            else AsyncAdapt_asyncpg_cursor
        )
        return cursor_cls(self)

    def rollback(self):
        if not self._started:
            return
        try:
            self.await_(self._transaction.rollback())
        except Exception as error:
            self._handle_exception(error)
        finally:
            self._transaction = None
            self._started = False

    def commit(self):
        if not self._started:
            return
        try:
            self.await_(self._transaction.commit())
        except Exception as error:
            self._handle_exception(error)
        finally:
            self._transaction = None
            self._started = False

    def close(self):
        self.rollback()
        self.await_(self._connection.close())

    def terminate(self):
        self._connection.terminate()
class AsyncAdaptFallback_asyncpg_connection(AsyncAdapt_asyncpg_connection):
    """Adapted connection that awaits via ``await_fallback``.

    Used when the ``async_fallback`` connect flag is set; otherwise identical
    to its base class.
    """

    __slots__ = ()

    await_ = staticmethod(await_fallback)
class AsyncAdapt_asyncpg_dbapi:
    """pep-249 module emulation wrapping the ``asyncpg`` package.

    Provides ``connect()``, the DBAPI exception hierarchy, the mapping from
    asyncpg exceptions to that hierarchy, and the type symbols used by the
    dialect's types.
    """

    def __init__(self, asyncpg):
        self.asyncpg = asyncpg
        self.paramstyle = "format"

    def connect(self, *arg, **kw):
        """Open an asyncpg connection and wrap it in an adapted connection.

        ``async_fallback`` and ``prepared_statement_cache_size`` are consumed
        here; all remaining arguments go to ``asyncpg.connect()``.
        """
        async_fallback = kw.pop("async_fallback", False)
        cache_size = kw.pop("prepared_statement_cache_size", 100)

        if util.asbool(async_fallback):
            adapter_cls = AsyncAdaptFallback_asyncpg_connection
            awaiter = await_fallback
        else:
            adapter_cls = AsyncAdapt_asyncpg_connection
            awaiter = await_only

        return adapter_cls(
            self,
            awaiter(self.asyncpg.connect(*arg, **kw)),
            prepared_statement_cache_size=cache_size,
        )

    class Error(Exception):
        pass

    class Warning(Exception):  # noqa
        pass

    class InterfaceError(Error):
        pass

    class DatabaseError(Error):
        pass

    class InternalError(DatabaseError):
        pass

    class OperationalError(DatabaseError):
        pass

    class ProgrammingError(DatabaseError):
        pass

    class IntegrityError(DatabaseError):
        pass

    class DataError(DatabaseError):
        pass

    class NotSupportedError(DatabaseError):
        pass

    class InternalServerError(InternalError):
        pass

    class InvalidCachedStatementError(NotSupportedError):
        def __init__(self, message):
            super().__init__(
                message + " (SQLAlchemy asyncpg dialect will now invalidate "
                "all prepared caches in response to this exception)",
            )

    @util.memoized_property
    def _asyncpg_error_translate(self):
        """Map asyncpg exception classes to the emulated DBAPI classes."""
        import asyncpg

        pg_exc = asyncpg.exceptions
        return {
            pg_exc.IntegrityConstraintViolationError: self.IntegrityError,
            pg_exc.PostgresError: self.Error,
            pg_exc.SyntaxOrAccessError: self.ProgrammingError,
            pg_exc.InterfaceError: self.InterfaceError,
            pg_exc.InvalidCachedStatementError: (
                self.InvalidCachedStatementError
            ),
            pg_exc.InternalServerError: self.InternalServerError,
        }

    def Binary(self, value):
        # binary values are passed through untouched
        return value

    STRING = util.symbol("STRING")
    TIMESTAMP = util.symbol("TIMESTAMP")
    TIMESTAMP_W_TZ = util.symbol("TIMESTAMP_W_TZ")
    TIME = util.symbol("TIME")
    TIME_W_TZ = util.symbol("TIME_W_TZ")
    DATE = util.symbol("DATE")
    INTERVAL = util.symbol("INTERVAL")
    NUMBER = util.symbol("NUMBER")
    FLOAT = util.symbol("FLOAT")
    BOOLEAN = util.symbol("BOOLEAN")
    INTEGER = util.symbol("INTEGER")
    BIGINTEGER = util.symbol("BIGINTEGER")
    BYTES = util.symbol("BYTES")
    DECIMAL = util.symbol("DECIMAL")
    JSON = util.symbol("JSON")
    JSONB = util.symbol("JSONB")
    ENUM = util.symbol("ENUM")
    UUID = util.symbol("UUID")
    BYTEA = util.symbol("BYTEA")

    # pep-249 aliases
    DATETIME = TIMESTAMP
    BINARY = BYTEA
# DBAPI type symbol -> PostgreSQL type name. The cursor's
# _parameter_placeholders() looks types up here to render "$n::<type>" casts
# when input sizes have been set.
_pg_types = {
    getattr(AsyncAdapt_asyncpg_dbapi, _symbol_name): _pg_type_name
    for _symbol_name, _pg_type_name in (
        ("STRING", "varchar"),
        ("TIMESTAMP", "timestamp"),
        ("TIMESTAMP_W_TZ", "timestamp with time zone"),
        ("DATE", "date"),
        ("TIME", "time"),
        ("TIME_W_TZ", "time with time zone"),
        ("INTERVAL", "interval"),
        ("NUMBER", "numeric"),
        ("FLOAT", "float"),
        ("BOOLEAN", "bool"),
        ("INTEGER", "integer"),
        ("BIGINTEGER", "bigint"),
        ("BYTES", "bytes"),
        ("DECIMAL", "decimal"),
        ("JSON", "json"),
        ("JSONB", "jsonb"),
        ("ENUM", "enum"),
        ("UUID", "uuid"),
        ("BYTEA", "bytea"),
    )
}
class PGDialect_asyncpg(PGDialect):
    """PostgreSQL dialect running on the asyncpg driver.

    The DBAPI is the emulation provided by ``AsyncAdapt_asyncpg_dbapi``.
    JSON/JSONB codecs are installed on every new connection by
    :meth:`on_connect`.
    """

    driver = "asyncpg"
    supports_statement_cache = True

    supports_unicode_statements = True
    supports_server_side_cursors = True

    supports_unicode_binds = True
    has_terminate = True

    default_paramstyle = "format"
    supports_sane_multi_rowcount = False
    execution_ctx_cls = PGExecutionContext_asyncpg
    statement_compiler = PGCompiler_asyncpg
    preparer = PGIdentifierPreparer_asyncpg

    use_setinputsizes = True

    use_native_uuid = True

    colspecs = util.update_copy(
        PGDialect.colspecs,
        {
            sqltypes.Time: AsyncpgTime,
            sqltypes.Date: AsyncpgDate,
            sqltypes.DateTime: AsyncpgDateTime,
            sqltypes.Interval: AsyncPgInterval,
            INTERVAL: AsyncPgInterval,
            UUID: AsyncpgUUID,
            sqltypes.Boolean: AsyncpgBoolean,
            sqltypes.Integer: AsyncpgInteger,
            sqltypes.BigInteger: AsyncpgBigInteger,
            sqltypes.Numeric: AsyncpgNumeric,
            sqltypes.Float: AsyncpgFloat,
            sqltypes.JSON: AsyncpgJSON,
            json.JSONB: AsyncpgJSONB,
            sqltypes.JSON.JSONPathType: AsyncpgJSONPathType,
            sqltypes.JSON.JSONIndexType: AsyncpgJSONIndexType,
            sqltypes.JSON.JSONIntIndexType: AsyncpgJSONIntIndexType,
            sqltypes.JSON.JSONStrIndexType: AsyncpgJSONStrIndexType,
            sqltypes.Enum: AsyncPgEnum,
            OID: AsyncpgOID,
            REGCLASS: AsyncpgREGCLASS,
        },
    )
    is_async = True

    # time.time() of the latest local invalidation; copied onto each cursor
    # by the execution context and compared by connections before they
    # reuse cached statements / type state
    _invalidate_schema_cache_asof = 0

    def _invalidate_schema_cache(self):
        """Record now as the point after which cached state is stale."""
        self._invalidate_schema_cache_asof = time.time()

    @util.memoized_property
    def _dbapi_version(self):
        """Parse ``self.dbapi.__version__`` into a tuple of ints.

        Returns ``(99, 99, 99)`` when there is no DBAPI or it has no
        ``__version__`` attribute.
        """
        # NOTE(review): AsyncAdapt_asyncpg_dbapi (this module) defines no
        # ``__version__``; if ``self.dbapi`` is that adapter, this presumably
        # always returns the fallback. Confirm whether the asyncpg module's
        # own version (``self.dbapi.asyncpg.__version__``) was intended.
        if self.dbapi and hasattr(self.dbapi, "__version__"):
            return tuple(
                [
                    int(x)
                    for x in re.findall(
                        r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
                    )
                ]
            )
        else:
            return (99, 99, 99)

    @classmethod
    def dbapi(cls):
        return AsyncAdapt_asyncpg_dbapi(__import__("asyncpg"))

    @util.memoized_property
    def _isolation_lookup(self):
        """SQLAlchemy isolation level names -> asyncpg isolation names."""
        return {
            "AUTOCOMMIT": "autocommit",
            "READ COMMITTED": "read_committed",
            "REPEATABLE READ": "repeatable_read",
            "SERIALIZABLE": "serializable",
        }

    def set_isolation_level(self, connection, level):
        """Apply ``level`` to the adapted connection.

        Underscores in ``level`` are treated as spaces.

        Raises:
            exc.ArgumentError: if ``level`` is not a known isolation level.
        """
        try:
            level = self._isolation_lookup[level.replace("_", " ")]
        except KeyError as err:
            util.raise_(
                exc.ArgumentError(
                    "Invalid value '%s' for isolation_level. "
                    "Valid isolation levels for %s are %s"
                    % (level, self.name, ", ".join(self._isolation_lookup))
                ),
                replace_context=err,
            )

        connection.set_isolation_level(level)

    def set_readonly(self, connection, value):
        connection.readonly = value

    def get_readonly(self, connection):
        return connection.readonly

    def set_deferrable(self, connection, value):
        connection.deferrable = value

    def get_deferrable(self, connection):
        return connection.deferrable

    def do_terminate(self, dbapi_connection) -> None:
        dbapi_connection.terminate()

    def create_connect_args(self, url):
        """Build ``connect()`` arguments from ``url``.

        URL query parameters are passed through as keyword arguments, with
        ``prepared_statement_cache_size`` and ``port`` coerced to int.
        """
        opts = url.translate_connect_args(username="user")

        opts.update(url.query)
        util.coerce_kw_type(opts, "prepared_statement_cache_size", int)
        util.coerce_kw_type(opts, "port", int)
        return ([], opts)

    @classmethod
    def get_pool_class(cls, url):
        """Choose the fallback pool when ``async_fallback`` is set in the URL."""

        async_fallback = url.query.get("async_fallback", False)

        if util.asbool(async_fallback):
            return pool.FallbackAsyncAdaptedQueuePool
        else:
            return pool.AsyncAdaptedQueuePool

    def is_disconnect(self, e, connection, cursor):
        if connection:
            return connection._connection.is_closed()
        else:
            return isinstance(
                e, self.dbapi.InterfaceError
            ) and "connection is closed" in str(e)

    def do_set_input_sizes(self, cursor, list_of_tuples, context):
        if self.positional:
            cursor.setinputsizes(
                *[dbtype for key, dbtype, sqltype in list_of_tuples]
            )
        else:
            cursor.setinputsizes(
                **{
                    key: dbtype
                    for key, dbtype, sqltype in list_of_tuples
                    if dbtype
                }
            )

    async def setup_asyncpg_json_codec(self, conn):
        """set up JSON codec for asyncpg.

        This occurs for all new connections and
        can be overridden by third party dialects.

        .. versionadded:: 1.4.27

        """

        asyncpg_connection = conn._connection
        deserializer = self._json_deserializer or _py_json.loads

        def _json_decoder(bin_value):
            return deserializer(bin_value.decode())

        await asyncpg_connection.set_type_codec(
            "json",
            encoder=str.encode,
            decoder=_json_decoder,
            schema="pg_catalog",
            format="binary",
        )

    async def setup_asyncpg_jsonb_codec(self, conn):
        """set up JSONB codec for asyncpg.

        This occurs for all new connections and
        can be overridden by third party dialects.

        .. versionadded:: 1.4.27

        """

        asyncpg_connection = conn._connection
        deserializer = self._json_deserializer or _py_json.loads

        def _jsonb_encoder(str_value):
            # \x01 is the prefix for jsonb used by PostgreSQL.
            # asyncpg requires it when format='binary'
            return b"\x01" + str_value.encode()

        def _jsonb_decoder(bin_value):
            # the byte is the \x01 prefix for jsonb used by PostgreSQL.
            # asyncpg returns it when format='binary'
            return deserializer(bin_value[1:].decode())

        await asyncpg_connection.set_type_codec(
            "jsonb",
            encoder=_jsonb_encoder,
            decoder=_jsonb_decoder,
            schema="pg_catalog",
            format="binary",
        )

    def on_connect(self):
        """on_connect for asyncpg

        A major component of this for asyncpg is to set up type decoders at the
        asyncpg level.

        See https://github.com/MagicStack/asyncpg/issues/623 for
        notes on JSON/JSONB implementation.

        """

        super_connect = super(PGDialect_asyncpg, self).on_connect()

        def connect(conn):
            conn.await_(self.setup_asyncpg_json_codec(conn))
            conn.await_(self.setup_asyncpg_jsonb_codec(conn))
            if super_connect is not None:
                super_connect(conn)

        return connect

    def get_driver_connection(self, connection):
        return connection._connection
# module-level alias for the dialect class defined above; presumably the
# name SQLAlchemy's dialect loader resolves -- confirm
dialect = PGDialect_asyncpg