1# engine/strategies.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""Strategies for creating new instances of Engine types.
9
10These are semi-private implementation classes which provide the
11underlying behavior for the "strategy" keyword argument available on
:func:`~sqlalchemy.engine.create_engine`. Currently available options are
13``plain``, ``threadlocal``, and ``mock``.
14
15New strategies can be added via new ``EngineStrategy`` classes.
16"""
17
18from operator import attrgetter
19
20from . import base
21from . import threadlocal
22from . import url
23from .. import event
24from .. import pool as poollib
25from .. import util
26from ..sql import schema
27
28
# Registry of EngineStrategy instances, keyed by each instance's ``name``
# attribute.  It is populated by EngineStrategy.__init__ whenever a
# strategy object is instantiated.
strategies = {}
30
31
class EngineStrategy(object):
    """Base adaptor that turns input arguments into an Engine.

    Concrete strategies override ``create``, which receives the input
    arguments and produces an instance of base.Engine or a subclass.
    Instantiating a strategy stores it in the module-level ``strategies``
    dictionary under its ``name`` attribute.

    """

    def __init__(self):
        # ``name`` is not defined on this base class; it is read from the
        # instance, so subclasses are expected to provide it.
        registry_key = self.name
        strategies[registry_key] = self

    def create(self, *args, **kwargs):
        """Return a new Engine instance built from the given arguments.

        This base implementation always raises ``NotImplementedError``.
        """

        raise NotImplementedError()
48
class DefaultEngineStrategy(EngineStrategy):
    """Shared ``create()`` implementation for the built-in strategies.

    ``create()`` instantiates ``self.engine_cls``, so subclasses are
    expected to provide that attribute.

    """

    def create(self, name_or_url, **kwargs):
        """Assemble a dialect, a connection pool and an Engine.

        Keyword arguments are consumed in stages: first those accepted by
        the dialect class, then those accepted by the dialect's ``dbapi``
        callable (unless a ``module`` was passed), then pool arguments
        (unless a ``pool`` was passed), and finally those accepted by the
        engine class.

        :param name_or_url: value handed to ``url.make_url()``.
        :param kwargs: configuration keywords; every one of them must be
         consumed by one of the stages above.
        :return: a new instance of ``self.engine_cls``.
        :raises TypeError: if any keyword argument is left unconsumed.
        """
        parsed_url = url.make_url(name_or_url)

        plugins = parsed_url._instantiate_plugins(kwargs)

        # the plugins have been instantiated; remove the "plugin" entry
        # from the URL query and the "plugins" entry from kwargs (a
        # leftover kwarg would trip the "all consumed" check below)
        parsed_url.query.pop("plugin", None)
        kwargs.pop("plugins", None)

        entrypoint = parsed_url._get_entrypoint()
        dialect_cls = entrypoint.get_dialect_cls(parsed_url)

        if not kwargs.pop("_coerce_config", False):
            pop_kwarg = kwargs.pop
        else:

            def pop_kwarg(key, default=None):
                # pop the value, then pass it through the callable the
                # dialect lists for this key in engine_config_types,
                # if there is one
                popped = kwargs.pop(key, default)
                converters = dialect_cls.engine_config_types
                if key in converters:
                    popped = converters[key](popped)
                return popped

        def consume(accepted, target, aliases=None):
            # Move every keyword named in ``accepted`` that is present in
            # kwargs into ``target``.  ``aliases`` optionally maps an
            # accepted name to the key it is passed under in kwargs.
            for accepted_name in accepted:
                if aliases is None:
                    kwarg_key = accepted_name
                else:
                    kwarg_key = aliases.get(accepted_name, accepted_name)
                if kwarg_key in kwargs:
                    target[accepted_name] = pop_kwarg(kwarg_key)

        # stage 1: arguments accepted by the dialect class
        dialect_args = {}
        consume(util.get_cls_kwargs(dialect_cls), dialect_args)

        # stage 2: the DBAPI module, either passed in as "module" or
        # produced by the dialect's dbapi() callable
        dbapi = kwargs.pop("module", None)
        if dbapi is None:
            dbapi_args = {}
            consume(util.get_func_kwargs(dialect_cls.dbapi), dbapi_args)
            dbapi = dialect_cls.dbapi(**dbapi_args)

        dialect_args["dbapi"] = dbapi

        for plugin in plugins:
            plugin.handle_dialect_kwargs(dialect_cls, dialect_args)

        dialect = dialect_cls(**dialect_args)

        # connection arguments derived from the URL, with any
        # user-supplied "connect_args" merged over the keyword part
        connect_seq, cparams = dialect.create_connect_args(parsed_url)
        cparams.update(pop_kwarg("connect_args", {}))
        cargs = list(connect_seq)  # a list, so it is mutable

        # stage 3: use the pool that was passed in, or build one
        pool = pop_kwarg("pool", None)
        if pool is not None:
            if isinstance(pool, poollib.dbapi_proxy._DBProxy):
                pool = pool.get_pool(*cargs, **cparams)
        else:

            def connect(connection_record=None):
                # give do_connect event handlers the first chance to
                # produce the connection; the first non-None result wins
                if dialect._has_events:
                    for hook in dialect.dispatch.do_connect:
                        hooked = hook(
                            dialect, connection_record, cargs, cparams
                        )
                        if hooked is not None:
                            return hooked
                return dialect.connect(*cargs, **cparams)

            creator = pop_kwarg("creator", connect)

            poolclass = pop_kwarg("poolclass", None)
            if poolclass is None:
                poolclass = dialect_cls.get_pool_class(parsed_url)
            pool_args = {"dialect": dialect}

            # pool constructor argument name -> the keyword it is passed
            # under in kwargs; names not listed are looked up unchanged
            translate = {
                "logging_name": "pool_logging_name",
                "echo": "echo_pool",
                "timeout": "pool_timeout",
                "recycle": "pool_recycle",
                "events": "pool_events",
                "use_threadlocal": "pool_threadlocal",
                "reset_on_return": "pool_reset_on_return",
                "pre_ping": "pool_pre_ping",
                "use_lifo": "pool_use_lifo",
            }
            consume(util.get_cls_kwargs(poolclass), pool_args, translate)

            for plugin in plugins:
                plugin.handle_pool_kwargs(poolclass, pool_args)

            pool = poolclass(creator, **pool_args)

        pool._dialect = dialect

        # stage 4: arguments accepted by the engine class
        engineclass = self.engine_cls
        engine_args = {}
        consume(util.get_cls_kwargs(engineclass), engine_args)

        _initialize = kwargs.pop("_initialize", True)

        # anything still present was not recognized by any stage
        if kwargs:
            leftover = ",".join("'%s'" % k for k in kwargs)
            raise TypeError(
                "Invalid argument(s) %s sent to create_engine(), "
                "using configuration %s/%s/%s. Please check that the "
                "keyword arguments are appropriate for this combination "
                "of components."
                % (
                    leftover,
                    dialect.__class__.__name__,
                    pool.__class__.__name__,
                    engineclass.__name__,
                )
            )

        engine = engineclass(pool, dialect, parsed_url, **engine_args)

        if _initialize:
            do_on_connect = dialect.on_connect()
            if do_on_connect:

                def on_connect(dbapi_connection, connection_record):
                    # use the object's "_sqla_unwrap" attribute when it
                    # has one, otherwise the connection itself
                    conn = getattr(
                        dbapi_connection, "_sqla_unwrap", dbapi_connection
                    )
                    if conn is not None:
                        do_on_connect(conn)

                for event_name in ("first_connect", "connect"):
                    event.listen(pool, event_name, on_connect)

            def first_connect(dbapi_connection, connection_record):
                # initialize the dialect through a Connection wrapping
                # this DBAPI connection, with events switched off and
                # empty execution options, then roll back
                conn = base.Connection(
                    engine, connection=dbapi_connection, _has_events=False
                )
                conn._execution_options = util.immutabledict()
                dialect.initialize(conn)
                dialect.do_rollback(conn.connection)

            event.listen(
                pool,
                "first_connect",
                first_connect,
                _once_unless_exception=True,
            )

        # notify the dialect class, the entrypoint (when it is a
        # different object) and every plugin that the engine exists
        dialect_cls.engine_created(engine)
        if entrypoint is not dialect_cls:
            entrypoint.engine_created(engine)

        for plugin in plugins:
            plugin.engine_created(engine)

        return engine
217
218
class PlainEngineStrategy(DefaultEngineStrategy):
    """Strategy for configuring a regular Engine."""

    # key under which the instance is stored in the ``strategies`` registry
    name = "plain"
    # class instantiated by DefaultEngineStrategy.create()
    engine_cls = base.Engine


# instantiating the strategy registers it (see EngineStrategy.__init__)
PlainEngineStrategy()
227
228
class ThreadLocalEngineStrategy(DefaultEngineStrategy):
    """Strategy for configuring an Engine with threadlocal behavior."""

    # key under which the instance is stored in the ``strategies`` registry
    name = "threadlocal"
    # class instantiated by DefaultEngineStrategy.create()
    engine_cls = threadlocal.TLEngine


# instantiating the strategy registers it (see EngineStrategy.__init__)
ThreadLocalEngineStrategy()
237
238
class MockEngineStrategy(EngineStrategy):
    """Strategy producing an Engine-like object whose execution is mocked.

    ``create()`` returns a single :class:`MockConnection`, a Connectable
    that hands statement execution to a caller-supplied function.

    """

    name = "mock"

    def create(self, name_or_url, executor, **kwargs):
        """Build a MockConnection for the dialect selected by the URL.

        :param name_or_url: value handed to ``url.make_url()``; the
         resulting URL supplies the dialect class.
        :param executor: callable installed as the ``execute`` attribute
         of the returned object.
        :param kwargs: keyword arguments; those accepted by the dialect
         class are passed to its constructor, the rest are not used.
        :return: a ``MockEngineStrategy.MockConnection``.
        """
        parsed_url = url.make_url(name_or_url)
        dialect_cls = parsed_url.get_dialect()

        # forward only the keyword arguments the dialect class accepts
        dialect_args = {
            key: kwargs[key]
            for key in util.get_cls_kwargs(dialect_cls)
            if key in kwargs
        }

        dialect = dialect_cls(**dialect_args)

        return MockEngineStrategy.MockConnection(dialect, executor)

    class MockConnection(base.Connectable):
        """Connectable that acts as its own engine and connection."""

        def __init__(self, dialect, execute):
            self._dialect = dialect
            # instance attribute; it takes the place of the execute()
            # method defined at the bottom of this class
            self.execute = execute

        # the object is its own engine; dialect and name are read from
        # the dialect given to __init__
        engine = property(lambda self: self)
        dialect = property(attrgetter("_dialect"))
        name = property(attrgetter("_dialect.name"))

        schema_for_object = schema._schema_getter(None)

        def contextual_connect(self, **kwargs):
            """Return this object; the arguments are ignored."""
            return self

        def connect(self, **kwargs):
            """Return this object; the arguments are ignored."""
            return self

        def execution_options(self, **kw):
            """Return this object; the options are ignored."""
            return self

        def compiler(self, statement, parameters, **kwargs):
            """Delegate to the dialect's ``compiler`` with engine=self."""
            return self._dialect.compiler(
                statement, parameters, engine=self, **kwargs
            )

        def create(self, entity, **kwargs):
            """Traverse ``entity`` with a SchemaGenerator.

            ``checkfirst`` is always forced to False.
            """
            from sqlalchemy.engine import ddl

            options = dict(kwargs, checkfirst=False)
            generator = ddl.SchemaGenerator(self.dialect, self, **options)
            generator.traverse_single(entity)

        def drop(self, entity, **kwargs):
            """Traverse ``entity`` with a SchemaDropper.

            ``checkfirst`` is always forced to False.
            """
            from sqlalchemy.engine import ddl

            options = dict(kwargs, checkfirst=False)
            dropper = ddl.SchemaDropper(self.dialect, self, **options)
            dropper.traverse_single(entity)

        def _run_visitor(
            self, visitorcallable, element, connection=None, **kwargs
        ):
            """Traverse ``element`` with a visitor from ``visitorcallable``.

            ``connection`` is accepted but not used; ``checkfirst`` is
            always forced to False.
            """
            options = dict(kwargs, checkfirst=False)
            visitor = visitorcallable(self.dialect, self, **options)
            visitor.traverse_single(element)

        def execute(self, object_, *multiparams, **params):
            """Class-level placeholder that always raises.

            Instances carry their own ``execute`` attribute, assigned in
            ``__init__``.
            """
            raise NotImplementedError()


# instantiating the strategy registers it (see EngineStrategy.__init__)
MockEngineStrategy()