1# engine/util.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8from __future__ import annotations
9
10import typing
11from typing import Any
12from typing import Callable
13from typing import Optional
14from typing import TypeVar
15
16from .. import exc
17from .. import util
18from ..util._has_cy import HAS_CYEXTENSION
19from ..util.typing import Protocol
20from ..util.typing import Self
21
22if typing.TYPE_CHECKING or not HAS_CYEXTENSION:
23 from ._py_util import _distill_params_20 as _distill_params_20
24 from ._py_util import _distill_raw_params as _distill_raw_params
25else:
26 from sqlalchemy.cyextension.util import ( # noqa: F401
27 _distill_params_20 as _distill_params_20,
28 )
29 from sqlalchemy.cyextension.util import ( # noqa: F401
30 _distill_raw_params as _distill_raw_params,
31 )
32
# Type of the functions accepted by ``connection_memoize``.  Those functions
# are invoked as ``fn(self, connection)``, so the bound has to accept
# arbitrary positional arguments; a ``Callable[[], Any]`` bound would only
# admit zero-argument callables and reject the decorator's real uses under
# a static type checker.
_C = TypeVar("_C", bound=Callable[..., Any])
34
35
def connection_memoize(key: str) -> Callable[[_C], _C]:
    """Decorator, memoize a function in a connection.info stash.

    The decorated function is invoked as ``fn(self, connection)``; apart
    from ``self`` it must take no arguments other than the connection.
    ``connection.connect()`` is called first, and the memo is looked up in,
    and on a miss stored into, ``info[key]`` of the object that call
    returns.
    """

    @util.decorator
    def decorated(fn, self, connection):  # type: ignore
        conn = connection.connect()
        try:
            memoized = conn.info[key]
        except KeyError:
            # nothing stashed under ``key`` yet: compute it once and keep it
            memoized = conn.info[key] = fn(self, conn)
        return memoized

    return decorated
53
54
class _TConsSubject(Protocol):
    """Structural type for the object a ``TransactionalContext`` attaches
    itself to (its "subject").

    ``TransactionalContext.__enter__`` saves this attribute's current value
    and replaces it with the entering context; ``__exit__`` puts the saved
    value back.
    """

    # innermost TransactionalContext currently attached, or None
    _trans_context_manager: Optional[TransactionalContext]
57
58
class TransactionalContext:
    """Apply Python context manager behavior to transaction objects.

    Performs validation to ensure the subject of the transaction is not
    used if the transaction were ended prematurely.

    Concrete subclasses supply the hooks :meth:`_get_subject`,
    :meth:`_transaction_is_active`, :meth:`_transaction_is_closed`,
    :meth:`_rollback_can_be_called`, :meth:`commit`, :meth:`rollback` and
    :meth:`close`; this base class implements ``__enter__`` / ``__exit__``
    purely in terms of those hooks.

    """

    __slots__ = ("_outer_trans_ctx", "_trans_subject", "__weakref__")

    _trans_subject: Optional[_TConsSubject]

    # ------------------------------------------------------------------
    # hooks which subclasses must override
    # ------------------------------------------------------------------

    def _transaction_is_active(self) -> bool:
        """Subclass hook.  ``__exit__`` commits only when this returns
        true, and :meth:`_trans_ctx_check` raises when it returns false."""
        raise NotImplementedError()

    def _transaction_is_closed(self) -> bool:
        """Subclass hook.  On exit, an inactive transaction for which this
        returns false is passed to :meth:`close`."""
        raise NotImplementedError()

    def _rollback_can_be_called(self) -> bool:
        """indicates the object is in a state that is known to be acceptable
        for rollback() to be called.

        This does not necessarily mean rollback() will succeed or not raise
        an error, just that there is currently no state detected that indicates
        rollback() would fail or emit warnings.

        It also does not mean that there's a transaction in progress, as
        it is usually safe to call rollback() even if no transaction is
        present.

        .. versionadded:: 1.4.28

        """
        raise NotImplementedError()

    def _get_subject(self) -> _TConsSubject:
        """Subclass hook returning the object that ``__enter__`` attaches
        this context to."""
        raise NotImplementedError()

    def commit(self) -> None:
        """Commit the transaction (subclass responsibility)."""
        raise NotImplementedError()

    def rollback(self) -> None:
        """Roll back the transaction (subclass responsibility)."""
        raise NotImplementedError()

    def close(self) -> None:
        """Close the transaction (subclass responsibility)."""
        raise NotImplementedError()

    # ------------------------------------------------------------------
    # context manager machinery
    # ------------------------------------------------------------------

    @classmethod
    def _trans_ctx_check(cls, subject: _TConsSubject) -> None:
        """Refuse further use of ``subject`` while it is attached to a
        context manager whose transaction is no longer active.

        :raises InvalidRequestError: if ``subject`` has an attached context
          manager and that context reports itself as not active.
        """
        active_ctx = subject._trans_context_manager
        if active_ctx and not active_ctx._transaction_is_active():
            raise exc.InvalidRequestError(
                "Can't operate on closed transaction inside context "
                "manager. Please complete the context manager "
                "before emitting further commands."
            )

    def __enter__(self) -> Self:
        """Attach this context to its subject, remembering whichever
        context was attached previously so ``__exit__`` can restore it."""
        subject = self._get_subject()

        # None for an outermost transaction; may be an enclosing context
        # for nested savepoint / legacy nesting cases
        self._outer_trans_ctx = subject._trans_context_manager
        self._trans_subject = subject
        subject._trans_context_manager = self
        return self

    def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
        """Finish the transaction when the ``with`` block ends.

        * No exception and the transaction is active: :meth:`commit`.  If
          the commit raises, :meth:`rollback` is attempted (when
          :meth:`_rollback_can_be_called` allows it) inside
          ``util.safe_reraise()``.
        * Otherwise, an active transaction is rolled back when
          :meth:`_rollback_can_be_called` allows it, and an inactive one is
          closed unless it is already closed.

        Afterwards the subject's previously attached context is restored,
        unless this exit is "out of band" (see below).
        """
        subject = getattr(self, "_trans_subject", None)

        # A naive implementation would assume
        # ``subject._trans_context_manager is self``.  Calling code may
        # drive __exit__ directly without a matching __enter__ (alembic's
        # context manager does this, and other code in the wild likely does
        # too), so detect that and leave the subject's state untouched.
        out_of_band_exit = (
            subject is None or subject._trans_context_manager is not self
        )

        # decided before entering the try: an exception raised while
        # evaluating this propagates without the subject being reset
        wants_commit = type_ is None and self._transaction_is_active()

        try:
            if wants_commit:
                try:
                    self.commit()
                except BaseException:
                    with util.safe_reraise():
                        if self._rollback_can_be_called():
                            self.rollback()
            elif self._transaction_is_active():
                if self._rollback_can_be_called():
                    self.rollback()
            elif not self._transaction_is_closed():
                self.close()
        finally:
            if not out_of_band_exit:
                assert subject is not None
                subject._trans_context_manager = self._outer_trans_ctx
                self._trans_subject = self._outer_trans_ctx = None