1# engine/util.py 
    2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors 
    3# <see AUTHORS file> 
    4# 
    5# This module is part of SQLAlchemy and is released under 
    6# the MIT License: https://www.opensource.org/licenses/mit-license.php 
    7 
    8from __future__ import annotations 
    9 
    10from typing import Any 
    11from typing import Callable 
    12from typing import Optional 
    13from typing import Protocol 
    14from typing import TypeVar 
    15 
    16from ._util_cy import _distill_params_20 as _distill_params_20  # noqa: F401 
    17from ._util_cy import _distill_raw_params as _distill_raw_params  # noqa: F401 
    18from .. import exc 
    19from .. import util 
    20from ..util.typing import Self 
    21 
# Type variable bound to zero-argument callables.  ``connection_memoize``
# returns a decorator annotated ``Callable[[_C], _C]``, i.e. one that hands
# back an object of the same static type it was given.
_C = TypeVar("_C", bound=Callable[[], Any])
    23 
    24 
    25def connection_memoize(key: str) -> Callable[[_C], _C]: 
    26    """Decorator, memoize a function in a connection.info stash. 
    27 
    28    Only applicable to functions which take no arguments other than a 
    29    connection.  The memo will be stored in ``connection.info[key]``. 
    30    """ 
    31 
    32    @util.decorator 
    33    def decorated(fn, self, connection):  # type: ignore 
    34        connection = connection.connect() 
    35        try: 
    36            return connection.info[key] 
    37        except KeyError: 
    38            connection.info[key] = val = fn(self, connection) 
    39            return val 
    40 
    41    return decorated 
    42 
    43 
class _TConsSubject(Protocol):
    """Structural type for objects a :class:`.TransactionalContext` can be
    entered against: anything exposing a ``_trans_context_manager``
    attribute.

    """

    # context manager currently registered on this subject, or None when
    # no context is registered
    _trans_context_manager: Optional[TransactionalContext]
    46 
    47 
class TransactionalContext:
    """Apply Python context manager behavior to transaction objects.

    Performs validation to ensure the subject of the transaction is not
    used if the transaction were ended prematurely.

    This base class supplies only the ``__enter__`` / ``__exit__``
    protocol and the :meth:`._trans_ctx_check` validation.  The state
    hooks (``_transaction_is_active()``, ``_transaction_is_closed()``,
    ``_rollback_can_be_called()``, ``_get_subject()``) and the
    ``commit()`` / ``rollback()`` / ``close()`` operations all raise
    ``NotImplementedError`` here and must be provided by subclasses.

    """

    # "__weakref__" is listed so instances stay weak-referenceable even
    # though __slots__ suppresses the per-instance __dict__.
    __slots__ = ("_outer_trans_ctx", "_trans_subject", "__weakref__")

    # subject this context was entered against; reset to None by __exit__
    _trans_subject: Optional[_TConsSubject]

    # context that was registered on the subject before __enter__ (None for
    # an outermost context); put back onto the subject by __exit__
    _outer_trans_ctx: Optional[TransactionalContext]

    def _transaction_is_active(self) -> bool:
        """Return whether the transaction is currently active.

        Subclass hook.  ``__exit__`` commits only when this is true and no
        exception is propagating; :meth:`._trans_ctx_check` raises when it
        is false for the subject's registered context.

        """
        raise NotImplementedError()

    def _transaction_is_closed(self) -> bool:
        """Return whether the transaction has already been closed.

        Subclass hook.  For an inactive transaction, ``__exit__`` calls
        :meth:`.close` only when this is false.

        """
        raise NotImplementedError()

    def _rollback_can_be_called(self) -> bool:
        """indicates the object is in a state that is known to be acceptable
        for rollback() to be called.

        This does not necessarily mean rollback() will succeed or not raise
        an error, just that there is currently no state detected that indicates
        rollback() would fail or emit warnings.

        It also does not mean that there's a transaction in progress, as
        it is usually safe to call rollback() even if no transaction is
        present.

        .. versionadded:: 1.4.28

        """
        raise NotImplementedError()

    def _get_subject(self) -> _TConsSubject:
        """Return the subject whose ``_trans_context_manager`` attribute
        this context manages.

        Subclass hook; called by ``__enter__``.

        """
        raise NotImplementedError()

    def commit(self) -> None:
        """Commit the transaction.  Subclass hook."""
        raise NotImplementedError()

    def rollback(self) -> None:
        """Roll back the transaction.  Subclass hook."""
        raise NotImplementedError()

    def close(self) -> None:
        """Close the transaction.  Subclass hook."""
        raise NotImplementedError()

    @classmethod
    def _trans_ctx_check(cls, subject: _TConsSubject) -> None:
        """Refuse use of ``subject`` after its context's transaction ended.

        If ``subject`` has a registered context manager whose transaction
        is no longer active, raise; otherwise return without effect.

        :raises exc.InvalidRequestError: the registered context's
          transaction is not active.

        """
        # NOTE(review): callers are outside this file section; presumably
        # the subject invokes this before emitting a command — confirm.
        trans_context = subject._trans_context_manager
        if trans_context:
            if not trans_context._transaction_is_active():
                raise exc.InvalidRequestError(
                    "Can't operate on closed transaction inside context "
                    "manager.  Please complete the context manager "
                    "before emitting further commands."
                )

    def __enter__(self) -> Self:
        """Register this context on its subject and return ``self``.

        Whatever context was already registered on the subject is saved
        in ``_outer_trans_ctx`` so ``__exit__`` can restore it.

        """
        subject = self._get_subject()

        # none for outer transaction, may be non-None for nested
        # savepoint, legacy nesting cases
        trans_context = subject._trans_context_manager
        self._outer_trans_ctx = trans_context

        self._trans_subject = subject
        subject._trans_context_manager = self
        return self

    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
        """Finish the transaction when the ``with`` block ends.

        * No exception propagating and transaction active: ``commit()``.
          If commit raises, ``rollback()`` is attempted when
          ``_rollback_can_be_called()`` is true, inside
          ``util.safe_reraise()``.
        * Otherwise, an inactive transaction that is not yet closed is
          ``close()``-d, and an active one is ``rollback()``-ed when
          ``_rollback_can_be_called()`` is true.

        In every path, the subject's previous context is restored (unless
        this is an out-of-band exit) and this context's references are
        cleared.  Returns None, so an exception raised in the block is
        not suppressed.

        """
        # None when __exit__ is called without a prior __enter__
        subject = getattr(self, "_trans_subject", None)

        # simplistically we could assume that
        # "subject._trans_context_manager is self".  However, any calling
        # code that is manipulating __exit__ directly would break this
        # assumption.  alembic context manager
        # is an example of partial use that just calls __exit__ and
        # not __enter__ at the moment.  it's safe to assume this is being done
        # in the wild also
        out_of_band_exit = (
            subject is None or subject._trans_context_manager is not self
        )

        if type_ is None and self._transaction_is_active():
            try:
                self.commit()
            except:
                # bare except also catches BaseException (e.g.
                # KeyboardInterrupt), so the rollback attempt happens for
                # those too; safe_reraise() is expected to re-raise the
                # original error afterwards
                with util.safe_reraise():
                    if self._rollback_can_be_called():
                        self.rollback()
            finally:
                # restore the outer context only if this context is still
                # the one registered on the subject
                if not out_of_band_exit:
                    assert subject is not None
                    subject._trans_context_manager = self._outer_trans_ctx
                # drop references whether or not commit succeeded
                self._trans_subject = self._outer_trans_ctx = None
        else:
            # an exception is propagating, or the transaction is no
            # longer active
            try:
                if not self._transaction_is_active():
                    if not self._transaction_is_closed():
                        self.close()
                else:
                    if self._rollback_can_be_called():
                        self.rollback()
            finally:
                # same cleanup as the commit path above
                if not out_of_band_exit:
                    assert subject is not None
                    subject._trans_context_manager = self._outer_trans_ctx
                self._trans_subject = self._outer_trans_ctx = None