1# engine/threadlocal.py 
    2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors 
    3# <see AUTHORS file> 
    4# 
    5# This module is part of SQLAlchemy and is released under 
    6# the MIT License: http://www.opensource.org/licenses/mit-license.php 
    7 
    8"""Provides a thread-local transactional wrapper around the root Engine class. 
    9 
The ``threadlocal`` module is used when the ``strategy="threadlocal"`` flag
is passed to :func:`~sqlalchemy.engine.create_engine`.  It is semi-private
and is loaded automatically by that engine strategy, which is deprecated as
of version 1.3.
    14""" 
    15 
    16import weakref 
    17 
    18from . import base 
    19from .. import util 
    20 
    21 
    22class TLConnection(base.Connection): 
    23    def __init__(self, *arg, **kw): 
    24        super(TLConnection, self).__init__(*arg, **kw) 
    25        self.__opencount = 0 
    26 
    27    def _increment_connect(self): 
    28        self.__opencount += 1 
    29        return self 
    30 
    31    def close(self): 
    32        if self.__opencount == 1: 
    33            base.Connection.close(self) 
    34        self.__opencount -= 1 
    35 
    36    def _force_close(self): 
    37        self.__opencount = 0 
    38        base.Connection.close(self) 
    39 
    40 
    41class TLEngine(base.Engine): 
    42    """An Engine that includes support for thread-local managed 
    43    transactions. 
    44 
    45    """ 
    46 
    47    _tl_connection_cls = TLConnection 
    48 
    49    @util.deprecated( 
    50        "1.3", 
    51        "The 'threadlocal' engine strategy is deprecated, and will be " 
    52        "removed in a future release.  The strategy is no longer relevant " 
    53        "to modern usage patterns (including that of the ORM " 
    54        ":class:`.Session` object) which make use of a " 
    55        ":class:`_engine.Connection` " 
    56        "object in order to invoke statements.", 
    57    ) 
    58    def __init__(self, *args, **kwargs): 
    59        super(TLEngine, self).__init__(*args, **kwargs) 
    60        self._connections = util.threading.local() 
    61 
    62    def contextual_connect(self, **kw): 
    63        return self._contextual_connect(**kw) 
    64 
    65    def _contextual_connect(self, **kw): 
    66        if not hasattr(self._connections, "conn"): 
    67            connection = None 
    68        else: 
    69            connection = self._connections.conn() 
    70 
    71        if connection is None or connection.closed: 
    72            # guards against pool-level reapers, if desired. 
    73            # or not connection.connection.is_valid: 
    74            connection = self._tl_connection_cls( 
    75                self, 
    76                self._wrap_pool_connect(self.pool.connect, connection), 
    77                **kw 
    78            ) 
    79            self._connections.conn = weakref.ref(connection) 
    80 
    81        return connection._increment_connect() 
    82 
    83    def begin_twophase(self, xid=None): 
    84        if not hasattr(self._connections, "trans"): 
    85            self._connections.trans = [] 
    86        self._connections.trans.append( 
    87            self._contextual_connect().begin_twophase(xid=xid) 
    88        ) 
    89        return self 
    90 
    91    def begin_nested(self): 
    92        if not hasattr(self._connections, "trans"): 
    93            self._connections.trans = [] 
    94        self._connections.trans.append( 
    95            self._contextual_connect().begin_nested() 
    96        ) 
    97        return self 
    98 
    99    def begin(self): 
    100        if not hasattr(self._connections, "trans"): 
    101            self._connections.trans = [] 
    102        self._connections.trans.append(self._contextual_connect().begin()) 
    103        return self 
    104 
    105    def __enter__(self): 
    106        return self 
    107 
    108    def __exit__(self, type_, value, traceback): 
    109        if type_ is None: 
    110            self.commit() 
    111        else: 
    112            self.rollback() 
    113 
    114    def prepare(self): 
    115        if ( 
    116            not hasattr(self._connections, "trans") 
    117            or not self._connections.trans 
    118        ): 
    119            return 
    120        self._connections.trans[-1].prepare() 
    121 
    122    def commit(self): 
    123        if ( 
    124            not hasattr(self._connections, "trans") 
    125            or not self._connections.trans 
    126        ): 
    127            return 
    128        trans = self._connections.trans.pop(-1) 
    129        trans.commit() 
    130 
    131    def rollback(self): 
    132        if ( 
    133            not hasattr(self._connections, "trans") 
    134            or not self._connections.trans 
    135        ): 
    136            return 
    137        trans = self._connections.trans.pop(-1) 
    138        trans.rollback() 
    139 
    140    def dispose(self): 
    141        self._connections = util.threading.local() 
    142        super(TLEngine, self).dispose() 
    143 
    144    @property 
    145    def closed(self): 
    146        return ( 
    147            not hasattr(self._connections, "conn") 
    148            or self._connections.conn() is None 
    149            or self._connections.conn().closed 
    150        ) 
    151 
    152    def close(self): 
    153        if not self.closed: 
    154            self._contextual_connect().close() 
    155            connection = self._connections.conn() 
    156            connection._force_close() 
    157            del self._connections.conn 
    158            self._connections.trans = [] 
    159 
    160    def __repr__(self): 
    161        return "TLEngine(%r)" % self.url