1# engine/strategies.py 
    2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors 
    3# <see AUTHORS file> 
    4# 
    5# This module is part of SQLAlchemy and is released under 
    6# the MIT License: http://www.opensource.org/licenses/mit-license.php 
    7 
    8"""Strategies for creating new instances of Engine types. 
    9 
    10These are semi-private implementation classes which provide the 
    11underlying behavior for the "strategy" keyword argument available on 
:func:`~sqlalchemy.engine.create_engine`.  Currently available options are
    13``plain``, ``threadlocal``, and ``mock``. 
    14 
    15New strategies can be added via new ``EngineStrategy`` classes. 
    16""" 
    17 
    18from operator import attrgetter 
    19 
    20from . import base 
    21from . import threadlocal 
    22from . import url 
    23from .. import event 
    24from .. import pool as poollib 
    25from .. import util 
    26from ..sql import schema 
    27 
    28 
    29strategies = {} 
    30 
    31 
    32class EngineStrategy(object): 
    33    """An adaptor that processes input arguments and produces an Engine. 
    34 
    35    Provides a ``create`` method that receives input arguments and 
    36    produces an instance of base.Engine or a subclass. 
    37 
    38    """ 
    39 
    40    def __init__(self): 
    41        strategies[self.name] = self 
    42 
    43    def create(self, *args, **kwargs): 
    44        """Given arguments, returns a new Engine instance.""" 
    45 
    46        raise NotImplementedError() 
    47 
    48 
    49class DefaultEngineStrategy(EngineStrategy): 
    50    """Base class for built-in strategies.""" 
    51 
    52    def create(self, name_or_url, **kwargs): 
    53        # create url.URL object 
    54        u = url.make_url(name_or_url) 
    55 
    56        plugins = u._instantiate_plugins(kwargs) 
    57 
    58        u.query.pop("plugin", None) 
    59        kwargs.pop("plugins", None) 
    60 
    61        entrypoint = u._get_entrypoint() 
    62        dialect_cls = entrypoint.get_dialect_cls(u) 
    63 
    64        if kwargs.pop("_coerce_config", False): 
    65 
    66            def pop_kwarg(key, default=None): 
    67                value = kwargs.pop(key, default) 
    68                if key in dialect_cls.engine_config_types: 
    69                    value = dialect_cls.engine_config_types[key](value) 
    70                return value 
    71 
    72        else: 
    73            pop_kwarg = kwargs.pop 
    74 
    75        dialect_args = {} 
    76        # consume dialect arguments from kwargs 
    77        for k in util.get_cls_kwargs(dialect_cls): 
    78            if k in kwargs: 
    79                dialect_args[k] = pop_kwarg(k) 
    80 
    81        dbapi = kwargs.pop("module", None) 
    82        if dbapi is None: 
    83            dbapi_args = {} 
    84            for k in util.get_func_kwargs(dialect_cls.dbapi): 
    85                if k in kwargs: 
    86                    dbapi_args[k] = pop_kwarg(k) 
    87            dbapi = dialect_cls.dbapi(**dbapi_args) 
    88 
    89        dialect_args["dbapi"] = dbapi 
    90 
    91        for plugin in plugins: 
    92            plugin.handle_dialect_kwargs(dialect_cls, dialect_args) 
    93 
    94        # create dialect 
    95        dialect = dialect_cls(**dialect_args) 
    96 
    97        # assemble connection arguments 
    98        (cargs, cparams) = dialect.create_connect_args(u) 
    99        cparams.update(pop_kwarg("connect_args", {})) 
    100        cargs = list(cargs)  # allow mutability 
    101 
    102        # look for existing pool or create 
    103        pool = pop_kwarg("pool", None) 
    104        if pool is None: 
    105 
    106            def connect(connection_record=None): 
    107                if dialect._has_events: 
    108                    for fn in dialect.dispatch.do_connect: 
    109                        connection = fn( 
    110                            dialect, connection_record, cargs, cparams 
    111                        ) 
    112                        if connection is not None: 
    113                            return connection 
    114                return dialect.connect(*cargs, **cparams) 
    115 
    116            creator = pop_kwarg("creator", connect) 
    117 
    118            poolclass = pop_kwarg("poolclass", None) 
    119            if poolclass is None: 
    120                poolclass = dialect_cls.get_pool_class(u) 
    121            pool_args = {"dialect": dialect} 
    122 
    123            # consume pool arguments from kwargs, translating a few of 
    124            # the arguments 
    125            translate = { 
    126                "logging_name": "pool_logging_name", 
    127                "echo": "echo_pool", 
    128                "timeout": "pool_timeout", 
    129                "recycle": "pool_recycle", 
    130                "events": "pool_events", 
    131                "use_threadlocal": "pool_threadlocal", 
    132                "reset_on_return": "pool_reset_on_return", 
    133                "pre_ping": "pool_pre_ping", 
    134                "use_lifo": "pool_use_lifo", 
    135            } 
    136            for k in util.get_cls_kwargs(poolclass): 
    137                tk = translate.get(k, k) 
    138                if tk in kwargs: 
    139                    pool_args[k] = pop_kwarg(tk) 
    140 
    141            for plugin in plugins: 
    142                plugin.handle_pool_kwargs(poolclass, pool_args) 
    143 
    144            pool = poolclass(creator, **pool_args) 
    145        else: 
    146            if isinstance(pool, poollib.dbapi_proxy._DBProxy): 
    147                pool = pool.get_pool(*cargs, **cparams) 
    148            else: 
    149                pool = pool 
    150 
    151            pool._dialect = dialect 
    152 
    153        # create engine. 
    154        engineclass = self.engine_cls 
    155        engine_args = {} 
    156        for k in util.get_cls_kwargs(engineclass): 
    157            if k in kwargs: 
    158                engine_args[k] = pop_kwarg(k) 
    159 
    160        _initialize = kwargs.pop("_initialize", True) 
    161 
    162        # all kwargs should be consumed 
    163        if kwargs: 
    164            raise TypeError( 
    165                "Invalid argument(s) %s sent to create_engine(), " 
    166                "using configuration %s/%s/%s.  Please check that the " 
    167                "keyword arguments are appropriate for this combination " 
    168                "of components." 
    169                % ( 
    170                    ",".join("'%s'" % k for k in kwargs), 
    171                    dialect.__class__.__name__, 
    172                    pool.__class__.__name__, 
    173                    engineclass.__name__, 
    174                ) 
    175            ) 
    176 
    177        engine = engineclass(pool, dialect, u, **engine_args) 
    178 
    179        if _initialize: 
    180            do_on_connect = dialect.on_connect() 
    181            if do_on_connect: 
    182 
    183                def on_connect(dbapi_connection, connection_record): 
    184                    conn = getattr( 
    185                        dbapi_connection, "_sqla_unwrap", dbapi_connection 
    186                    ) 
    187                    if conn is None: 
    188                        return 
    189                    do_on_connect(conn) 
    190 
    191                event.listen(pool, "first_connect", on_connect) 
    192                event.listen(pool, "connect", on_connect) 
    193 
    194            def first_connect(dbapi_connection, connection_record): 
    195                c = base.Connection( 
    196                    engine, connection=dbapi_connection, _has_events=False 
    197                ) 
    198                c._execution_options = util.immutabledict() 
    199                dialect.initialize(c) 
    200                dialect.do_rollback(c.connection) 
    201 
    202            event.listen( 
    203                pool, 
    204                "first_connect", 
    205                first_connect, 
    206                _once_unless_exception=True, 
    207            ) 
    208 
    209        dialect_cls.engine_created(engine) 
    210        if entrypoint is not dialect_cls: 
    211            entrypoint.engine_created(engine) 
    212 
    213        for plugin in plugins: 
    214            plugin.engine_created(engine) 
    215 
    216        return engine 
    217 
    218 
    219class PlainEngineStrategy(DefaultEngineStrategy): 
    220    """Strategy for configuring a regular Engine.""" 
    221 
    222    name = "plain" 
    223    engine_cls = base.Engine 
    224 
    225 
    226PlainEngineStrategy() 
    227 
    228 
    229class ThreadLocalEngineStrategy(DefaultEngineStrategy): 
    230    """Strategy for configuring an Engine with threadlocal behavior.""" 
    231 
    232    name = "threadlocal" 
    233    engine_cls = threadlocal.TLEngine 
    234 
    235 
    236ThreadLocalEngineStrategy() 
    237 
    238 
    239class MockEngineStrategy(EngineStrategy): 
    240    """Strategy for configuring an Engine-like object with mocked execution. 
    241 
    242    Produces a single mock Connectable object which dispatches 
    243    statement execution to a passed-in function. 
    244 
    245    """ 
    246 
    247    name = "mock" 
    248 
    249    def create(self, name_or_url, executor, **kwargs): 
    250        # create url.URL object 
    251        u = url.make_url(name_or_url) 
    252 
    253        dialect_cls = u.get_dialect() 
    254 
    255        dialect_args = {} 
    256        # consume dialect arguments from kwargs 
    257        for k in util.get_cls_kwargs(dialect_cls): 
    258            if k in kwargs: 
    259                dialect_args[k] = kwargs.pop(k) 
    260 
    261        # create dialect 
    262        dialect = dialect_cls(**dialect_args) 
    263 
    264        return MockEngineStrategy.MockConnection(dialect, executor) 
    265 
    266    class MockConnection(base.Connectable): 
    267        def __init__(self, dialect, execute): 
    268            self._dialect = dialect 
    269            self.execute = execute 
    270 
    271        engine = property(lambda s: s) 
    272        dialect = property(attrgetter("_dialect")) 
    273        name = property(lambda s: s._dialect.name) 
    274 
    275        schema_for_object = schema._schema_getter(None) 
    276 
    277        def contextual_connect(self, **kwargs): 
    278            return self 
    279 
    280        def connect(self, **kwargs): 
    281            return self 
    282 
    283        def execution_options(self, **kw): 
    284            return self 
    285 
    286        def compiler(self, statement, parameters, **kwargs): 
    287            return self._dialect.compiler( 
    288                statement, parameters, engine=self, **kwargs 
    289            ) 
    290 
    291        def create(self, entity, **kwargs): 
    292            kwargs["checkfirst"] = False 
    293            from sqlalchemy.engine import ddl 
    294 
    295            ddl.SchemaGenerator(self.dialect, self, **kwargs).traverse_single( 
    296                entity 
    297            ) 
    298 
    299        def drop(self, entity, **kwargs): 
    300            kwargs["checkfirst"] = False 
    301            from sqlalchemy.engine import ddl 
    302 
    303            ddl.SchemaDropper(self.dialect, self, **kwargs).traverse_single( 
    304                entity 
    305            ) 
    306 
    307        def _run_visitor( 
    308            self, visitorcallable, element, connection=None, **kwargs 
    309        ): 
    310            kwargs["checkfirst"] = False 
    311            visitorcallable(self.dialect, self, **kwargs).traverse_single( 
    312                element 
    313            ) 
    314 
    315        def execute(self, object_, *multiparams, **params): 
    316            raise NotImplementedError() 
    317 
    318 
    319MockEngineStrategy()