1# engine/threadlocal.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""Provides a thread-local transactional wrapper around the root Engine class.
9
10The ``threadlocal`` module is invoked when using the
11``strategy="threadlocal"`` flag with :func:`~sqlalchemy.engine.create_engine`.
12This module is semi-private and is invoked automatically when the threadlocal
13engine strategy is used.
14"""
15
16import weakref
17
18from . import base
19from .. import util
20
21
class TLConnection(base.Connection):
    """A :class:`.Connection` whose ``close()`` is reference-counted.

    Every call to ``_increment_connect()`` registers one more outstanding
    handle on this connection.  ``close()`` only performs the real
    ``base.Connection.close`` when the handle being released is the last
    one (count == 1); any other call just decrements the count.
    ``_force_close()`` discards the count and closes unconditionally.

    """

    def __init__(self, *arg, **kw):
        super(TLConnection, self).__init__(*arg, **kw)
        # Outstanding-handle counter; starts at zero and is bumped by
        # _increment_connect() each time the engine hands this object out.
        self.__opencount = 0

    def _increment_connect(self):
        """Register one more outstanding handle and return ``self``."""
        self.__opencount += 1
        return self

    def close(self):
        """Release one handle; really close only when it was the last one.

        Note that calls beyond the number of registered handles keep
        decrementing, so the counter can drop below zero.
        """
        releasing_last_handle = self.__opencount == 1
        if releasing_last_handle:
            base.Connection.close(self)
        self.__opencount -= 1

    def _force_close(self):
        """Close immediately, regardless of any outstanding handles."""
        # Reset first so the counter reads zero even if the close raises.
        self.__opencount = 0
        base.Connection.close(self)
39
40
class TLEngine(base.Engine):
    """An Engine that includes support for thread-local managed
    transactions.

    Per-thread state lives on ``self._connections`` (a thread-local):

    * ``conn`` -- a ``weakref.ref`` to this thread's connection, an
      instance of ``_tl_connection_cls``;
    * ``trans`` -- a list used as a stack of transactions begun in this
      thread.

    ``begin()``, ``begin_nested()`` and ``begin_twophase()`` push onto that
    stack and return the engine itself, so the engine can be used in a
    ``with`` statement.  ``commit()``, ``rollback()`` and ``prepare()`` act
    on the innermost (most recently begun) transaction.

    """

    _tl_connection_cls = TLConnection

    @util.deprecated(
        "1.3",
        "The 'threadlocal' engine strategy is deprecated, and will be "
        "removed in a future release. The strategy is no longer relevant "
        "to modern usage patterns (including that of the ORM "
        ":class:`.Session` object) which make use of a "
        ":class:`_engine.Connection` "
        "object in order to invoke statements.",
    )
    def __init__(self, *args, **kwargs):
        super(TLEngine, self).__init__(*args, **kwargs)
        self._connections = util.threading.local()

    def _tl_current_connection(self):
        """Return this thread's connection, or None if absent or collected.

        The weak reference is dereferenced exactly once, so callers hold a
        strong reference (or None) and cannot observe the referent vanishing
        between two dereferences.
        """
        ref = getattr(self._connections, "conn", None)
        if ref is None:
            return None
        return ref()

    def _tl_trans_stack(self):
        """Return this thread's transaction stack, creating it if missing."""
        try:
            return self._connections.trans
        except AttributeError:
            stack = self._connections.trans = []
            return stack

    def _tl_existing_trans_stack(self):
        """Return this thread's transaction stack, or None if never created.

        Unlike ``_tl_trans_stack()`` this does not create the stack; callers
        treat None and an empty list alike as "no transaction".
        """
        return getattr(self._connections, "trans", None)

    def contextual_connect(self, **kw):
        """Public entry point; delegates to :meth:`_contextual_connect`."""
        return self._contextual_connect(**kw)

    def _contextual_connect(self, **kw):
        """Return this thread's connection, opening a new one if needed.

        A new connection is created when the thread has none, when the
        weakly-referenced one has been collected, or when it is closed.
        Each call registers one more handle on the returned connection via
        ``_increment_connect()``.

        :param kw: passed through to ``_tl_connection_cls`` when a new
            connection is constructed.
        """
        connection = self._tl_current_connection()

        if connection is None or connection.closed:
            # NOTE: a check such as ``not connection.connection.is_valid``
            # could be added here to guard against pool-level reapers.
            connection = self._tl_connection_cls(
                self,
                self._wrap_pool_connect(self.pool.connect, connection),
                **kw
            )
            # Only a weak reference is kept, so the engine itself does not
            # keep the connection alive.
            self._connections.conn = weakref.ref(connection)

        return connection._increment_connect()

    def begin_twophase(self, xid=None):
        """Begin a two-phase transaction and push it onto this thread's
        stack.

        :param xid: passed through to the connection's ``begin_twophase``.
        :return: this engine, so it can be used as a context manager.
        """
        self._tl_trans_stack().append(
            self._contextual_connect().begin_twophase(xid=xid)
        )
        return self

    def begin_nested(self):
        """Begin a nested transaction and push it onto this thread's stack.

        :return: this engine, so it can be used as a context manager.
        """
        self._tl_trans_stack().append(
            self._contextual_connect().begin_nested()
        )
        return self

    def begin(self):
        """Begin a transaction and push it onto this thread's stack.

        :return: this engine, so it can be used as a context manager.
        """
        self._tl_trans_stack().append(self._contextual_connect().begin())
        return self

    def __enter__(self):
        return self

    def __exit__(self, type_, value, traceback):
        """Commit the innermost transaction on a clean exit, otherwise roll
        it back.  Returns None, so exceptions are never suppressed."""
        if type_ is None:
            self.commit()
        else:
            self.rollback()

    def prepare(self):
        """Prepare the innermost transaction; a no-op if there is none.

        The transaction stays on the stack.
        """
        stack = self._tl_existing_trans_stack()
        if stack:
            stack[-1].prepare()

    def commit(self):
        """Pop and commit the innermost transaction; a no-op if none.

        The transaction is removed from the stack before ``commit()`` is
        called, so it is not left on the stack if the commit raises.
        """
        stack = self._tl_existing_trans_stack()
        if stack:
            stack.pop(-1).commit()

    def rollback(self):
        """Pop and roll back the innermost transaction; a no-op if none.

        The transaction is removed from the stack before ``rollback()`` is
        called.
        """
        stack = self._tl_existing_trans_stack()
        if stack:
            stack.pop(-1).rollback()

    def dispose(self):
        """Drop the thread-local state, then dispose via the base Engine.

        Replacing the thread-local object discards the connection and
        transaction state of every thread, not just the calling one.
        """
        self._connections = util.threading.local()
        super(TLEngine, self).dispose()

    @property
    def closed(self):
        """True when this thread has no live, open connection."""
        connection = self._tl_current_connection()
        return connection is None or connection.closed

    def close(self):
        """Force-close this thread's connection and clear its transactions.

        The transaction stack is cleared even when no connection is open.
        """
        if not self.closed:
            # Keep a strong reference for the whole sequence.  The original
            # code re-dereferenced the weakref here, which could yield None
            # if the connection was collected in between.
            connection = self._contextual_connect()
            # Release the handle just taken.  With the default TLConnection
            # this really closes only when it is the last handle.
            connection.close()
            # Then close unconditionally, ignoring any outstanding handles.
            connection._force_close()
            del self._connections.conn
        self._connections.trans = []

    def __repr__(self):
        return "TLEngine(%r)" % self.url