1# sql/visitors.py 
    2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors 
    3# <see AUTHORS file> 
    4# 
    5# This module is part of SQLAlchemy and is released under 
    6# the MIT License: http://www.opensource.org/licenses/mit-license.php 
    7 
    8"""Visitor/traversal interface and library functions. 
    9 
    10SQLAlchemy schema and expression constructs rely on a Python-centric 
    11version of the classic "visitor" pattern as the primary way in which 
    12they apply functionality.  The most common use of this pattern 
    13is statement compilation, where individual expression classes match 
    14up to rendering methods that produce a string result.   Beyond this, 
    15the visitor system is also used to inspect expressions for various 
    16information and patterns, as well as for the purposes of applying 
    17transformations to expressions. 
    18 
Examples of how the visit system is used can be seen in the source code
of, for example, the ``sqlalchemy.sql.util`` and ``sqlalchemy.sql.compiler``
modules.  Some background on clause adaption is also available at
http://techspot.zzzeek.org/2008/01/23/expression-transformations/ .
    23 
    24""" 
    25 
    26from collections import deque 
    27import operator 
    28 
    29from .. import exc 
    30from .. import util 
    31 
    32 
    33__all__ = [ 
    34    "VisitableType", 
    35    "Visitable", 
    36    "ClauseVisitor", 
    37    "CloningVisitor", 
    38    "ReplacingCloningVisitor", 
    39    "iterate", 
    40    "iterate_depthfirst", 
    41    "traverse_using", 
    42    "traverse", 
    43    "traverse_depthfirst", 
    44    "cloned_traverse", 
    45    "replacement_traverse", 
    46] 
    47 
    48 
    49class VisitableType(type): 
    50    """Metaclass which assigns a ``_compiler_dispatch`` method to classes 
    51    having a ``__visit_name__`` attribute. 
    52 
    53    The ``_compiler_dispatch`` attribute becomes an instance method which 
    54    looks approximately like the following:: 
    55 
    56        def _compiler_dispatch (self, visitor, **kw): 
    57            '''Look for an attribute named "visit_" + self.__visit_name__ 
    58            on the visitor, and call it with the same kw params.''' 
    59            visit_attr = 'visit_%s' % self.__visit_name__ 
    60            return getattr(visitor, visit_attr)(self, **kw) 
    61 
    62    Classes having no ``__visit_name__`` attribute will remain unaffected. 
    63 
    64    """ 
    65 
    66    def __init__(cls, clsname, bases, clsdict): 
    67        if clsname != "Visitable" and hasattr(cls, "__visit_name__"): 
    68            _generate_dispatch(cls) 
    69 
    70        super(VisitableType, cls).__init__(clsname, bases, clsdict) 
    71 
    72 
    73def _generate_dispatch(cls): 
    74    """Return an optimized visit dispatch function for the cls 
    75    for use by the compiler. 
    76 
    77    """ 
    78    if "__visit_name__" in cls.__dict__: 
    79        visit_name = cls.__visit_name__ 
    80 
    81        if isinstance(visit_name, util.compat.string_types): 
    82            # There is an optimization opportunity here because the 
    83            # the string name of the class's __visit_name__ is known at 
    84            # this early stage (import time) so it can be pre-constructed. 
    85            getter = operator.attrgetter("visit_%s" % visit_name) 
    86 
    87            def _compiler_dispatch(self, visitor, **kw): 
    88                try: 
    89                    meth = getter(visitor) 
    90                except AttributeError as err: 
    91                    util.raise_( 
    92                        exc.UnsupportedCompilationError(visitor, cls), 
    93                        replace_context=err, 
    94                    ) 
    95                else: 
    96                    return meth(self, **kw) 
    97 
    98        else: 
    99            # The optimization opportunity is lost for this case because the 
    100            # __visit_name__ is not yet a string. As a result, the visit 
    101            # string has to be recalculated with each compilation. 
    102            def _compiler_dispatch(self, visitor, **kw): 
    103                visit_attr = "visit_%s" % self.__visit_name__ 
    104                try: 
    105                    meth = getattr(visitor, visit_attr) 
    106                except AttributeError as err: 
    107                    util.raise_( 
    108                        exc.UnsupportedCompilationError(visitor, cls), 
    109                        replace_context=err, 
    110                    ) 
    111                else: 
    112                    return meth(self, **kw) 
    113 
    114        _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__ 
    115            on the visitor, and call it with the same kw params. 
    116            """ 
    117        cls._compiler_dispatch = _compiler_dispatch 
    118 
    119 
class Visitable(util.with_metaclass(VisitableType, object)):
    """Base class for visitable objects, applies the
    :class:`.visitors.VisitableType` metaclass.

    Subclasses that carry a ``__visit_name__`` are given a
    ``_compiler_dispatch()`` method by the metaclass; the class named
    ``Visitable`` itself is explicitly skipped by the metaclass.

    The :class:`.Visitable` class is essentially at the base of the
    :class:`_expression.ClauseElement` hierarchy.

    """
    128 
    129 
    130class ClauseVisitor(object): 
    131    """Base class for visitor objects which can traverse using 
    132    the :func:`.visitors.traverse` function. 
    133 
    134    Direct usage of the :func:`.visitors.traverse` function is usually 
    135    preferred. 
    136 
    137    """ 
    138 
    139    __traverse_options__ = {} 
    140 
    141    def traverse_single(self, obj, **kw): 
    142        for v in self.visitor_iterator: 
    143            meth = getattr(v, "visit_%s" % obj.__visit_name__, None) 
    144            if meth: 
    145                return meth(obj, **kw) 
    146 
    147    def iterate(self, obj): 
    148        """Traverse the given expression structure, returning an iterator 
    149        of all elements. 
    150 
    151        """ 
    152        return iterate(obj, self.__traverse_options__) 
    153 
    154    def traverse(self, obj): 
    155        """Traverse and visit the given expression structure.""" 
    156 
    157        return traverse(obj, self.__traverse_options__, self._visitor_dict) 
    158 
    159    @util.memoized_property 
    160    def _visitor_dict(self): 
    161        visitors = {} 
    162 
    163        for name in dir(self): 
    164            if name.startswith("visit_"): 
    165                visitors[name[6:]] = getattr(self, name) 
    166        return visitors 
    167 
    168    @property 
    169    def visitor_iterator(self): 
    170        """Iterate through this visitor and each 'chained' visitor.""" 
    171 
    172        v = self 
    173        while v: 
    174            yield v 
    175            v = getattr(v, "_next", None) 
    176 
    177    def chain(self, visitor): 
    178        """'Chain' an additional ClauseVisitor onto this ClauseVisitor. 
    179 
    180        The chained visitor will receive all visit events after this one. 
    181 
    182        """ 
    183        tail = list(self.visitor_iterator)[-1] 
    184        tail._next = visitor 
    185        return self 
    186 
    187 
    188class CloningVisitor(ClauseVisitor): 
    189    """Base class for visitor objects which can traverse using 
    190    the :func:`.visitors.cloned_traverse` function. 
    191 
    192    Direct usage of the :func:`.visitors.cloned_traverse` function is usually 
    193    preferred. 
    194 
    195 
    196    """ 
    197 
    198    def copy_and_process(self, list_): 
    199        """Apply cloned traversal to the given list of elements, and return 
    200        the new list. 
    201 
    202        """ 
    203        return [self.traverse(x) for x in list_] 
    204 
    205    def traverse(self, obj): 
    206        """Traverse and visit the given expression structure.""" 
    207 
    208        return cloned_traverse( 
    209            obj, self.__traverse_options__, self._visitor_dict 
    210        ) 
    211 
    212 
    213class ReplacingCloningVisitor(CloningVisitor): 
    214    """Base class for visitor objects which can traverse using 
    215    the :func:`.visitors.replacement_traverse` function. 
    216 
    217    Direct usage of the :func:`.visitors.replacement_traverse` function is 
    218    usually preferred. 
    219 
    220    """ 
    221 
    222    def replace(self, elem): 
    223        """Receive pre-copied elements during a cloning traversal. 
    224 
    225        If the method returns a new element, the element is used 
    226        instead of creating a simple copy of the element.  Traversal 
    227        will halt on the newly returned element if it is re-encountered. 
    228        """ 
    229        return None 
    230 
    231    def traverse(self, obj): 
    232        """Traverse and visit the given expression structure.""" 
    233 
    234        def replace(elem): 
    235            for v in self.visitor_iterator: 
    236                e = v.replace(elem) 
    237                if e is not None: 
    238                    return e 
    239 
    240        return replacement_traverse(obj, self.__traverse_options__, replace) 
    241 
    242 
    243def iterate(obj, opts): 
    244    r"""Traverse the given expression structure, returning an iterator. 
    245 
    246    Traversal is configured to be breadth-first. 
    247 
    248    The central API feature used by the :func:`.visitors.iterate` and 
    249    :func:`.visitors.iterate_depthfirst` functions is the 
    250    :meth:`_expression.ClauseElement.get_children` method of 
    251    :class:`_expression.ClauseElement` objects.  This method should return all 
    252    the :class:`_expression.ClauseElement` objects which are associated with a 
    253    particular :class:`_expression.ClauseElement` object. For example, a 
    254    :class:`.Case` structure will refer to a series of 
    255    :class:`_expression.ColumnElement` objects within its "whens" and "else\_" 
    256    member variables. 
    257 
    258    :param obj: :class:`_expression.ClauseElement` structure to be traversed 
    259 
    260    :param opts: dictionary of iteration options.   This dictionary is usually 
    261     empty in modern usage. 
    262 
    263    """ 
    264    # fasttrack for atomic elements like columns 
    265    children = obj.get_children(**opts) 
    266    if not children: 
    267        return [obj] 
    268 
    269    traversal = deque() 
    270    stack = deque([obj]) 
    271    while stack: 
    272        t = stack.popleft() 
    273        traversal.append(t) 
    274        for c in t.get_children(**opts): 
    275            stack.append(c) 
    276    return iter(traversal) 
    277 
    278 
    279def iterate_depthfirst(obj, opts): 
    280    """Traverse the given expression structure, returning an iterator. 
    281 
    282    Traversal is configured to be depth-first. 
    283 
    284    :param obj: :class:`_expression.ClauseElement` structure to be traversed 
    285 
    286    :param opts: dictionary of iteration options.   This dictionary is usually 
    287     empty in modern usage. 
    288 
    289    .. seealso:: 
    290 
    291        :func:`.visitors.iterate` - includes a general overview of iteration. 
    292 
    293    """ 
    294    # fasttrack for atomic elements like columns 
    295    children = obj.get_children(**opts) 
    296    if not children: 
    297        return [obj] 
    298 
    299    stack = deque([obj]) 
    300    traversal = deque() 
    301    while stack: 
    302        t = stack.pop() 
    303        traversal.appendleft(t) 
    304        for c in t.get_children(**opts): 
    305            stack.append(c) 
    306    return iter(traversal) 
    307 
    308 
    309def traverse_using(iterator, obj, visitors): 
    310    """Visit the given expression structure using the given iterator of 
    311    objects. 
    312 
    313    :func:`.visitors.traverse_using` is usually called internally as the result 
    314    of the :func:`.visitors.traverse` or :func:`.visitors.traverse_depthfirst` 
    315    functions. 
    316 
    317    :param iterator: an iterable or sequence which will yield 
    318     :class:`_expression.ClauseElement` 
    319     structures; the iterator is assumed to be the 
    320     product of the :func:`.visitors.iterate` or 
    321     :func:`.visitors.iterate_depthfirst` functions. 
    322 
    323    :param obj: the :class:`_expression.ClauseElement` 
    324     that was used as the target of the 
    325     :func:`.iterate` or :func:`.iterate_depthfirst` function. 
    326 
    327    :param visitors: dictionary of visit functions.  See :func:`.traverse` 
    328     for details on this dictionary. 
    329 
    330    .. seealso:: 
    331 
    332        :func:`.traverse` 
    333 
    334        :func:`.traverse_depthfirst` 
    335 
    336    """ 
    337    for target in iterator: 
    338        meth = visitors.get(target.__visit_name__, None) 
    339        if meth: 
    340            meth(target) 
    341    return obj 
    342 
    343 
    344def traverse(obj, opts, visitors): 
    345    """Traverse and visit the given expression structure using the default 
    346    iterator. 
    347 
    348     e.g.:: 
    349 
    350        from sqlalchemy.sql import visitors 
    351 
    352        stmt = select([some_table]).where(some_table.c.foo == 'bar') 
    353 
    354        def visit_bindparam(bind_param): 
    355            print("found bound value: %s" % bind_param.value) 
    356 
    357        visitors.traverse(stmt, {}, {"bindparam": visit_bindparam}) 
    358 
    359    The iteration of objects uses the :func:`.visitors.iterate` function, 
    360    which does a breadth-first traversal using a stack. 
    361 
    362    :param obj: :class:`_expression.ClauseElement` structure to be traversed 
    363 
    364    :param opts: dictionary of iteration options.   This dictionary is usually 
    365     empty in modern usage. 
    366 
    367    :param visitors: dictionary of visit functions.   The dictionary should 
    368     have strings as keys, each of which would correspond to the 
    369     ``__visit_name__`` of a particular kind of SQL expression object, and 
    370     callable functions  as values, each of which represents a visitor function 
    371     for that kind of object. 
    372 
    373    """ 
    374    return traverse_using(iterate(obj, opts), obj, visitors) 
    375 
    376 
    377def traverse_depthfirst(obj, opts, visitors): 
    378    """traverse and visit the given expression structure using the 
    379    depth-first iterator. 
    380 
    381    The iteration of objects uses the :func:`.visitors.iterate_depthfirst` 
    382    function, which does a depth-first traversal using a stack. 
    383 
    384    Usage is the same as that of :func:`.visitors.traverse` function. 
    385 
    386 
    387    """ 
    388    return traverse_using(iterate_depthfirst(obj, opts), obj, visitors) 
    389 
    390 
    391def cloned_traverse(obj, opts, visitors): 
    392    """Clone the given expression structure, allowing modifications by 
    393    visitors. 
    394 
    395    Traversal usage is the same as that of :func:`.visitors.traverse`. 
    396    The visitor functions present in the ``visitors`` dictionary may also 
    397    modify the internals of the given structure as the traversal proceeds. 
    398 
    399    The central API feature used by the :func:`.visitors.cloned_traverse` 
    400    and :func:`.visitors.replacement_traverse` functions, in addition to the 
    401    :meth:`_expression.ClauseElement.get_children` 
    402    function that is used to achieve 
    403    the iteration, is the :meth:`_expression.ClauseElement._copy_internals` 
    404    method. 
    405    For a :class:`_expression.ClauseElement` 
    406    structure to support cloning and replacement 
    407    traversals correctly, it needs to be able to pass a cloning function into 
    408    its internal members in order to make copies of them. 
    409 
    410    .. seealso:: 
    411 
    412        :func:`.visitors.traverse` 
    413 
    414        :func:`.visitors.replacement_traverse` 
    415 
    416    """ 
    417 
    418    cloned = {} 
    419    stop_on = set(opts.get("stop_on", [])) 
    420 
    421    def clone(elem, **kw): 
    422        if elem in stop_on: 
    423            return elem 
    424        else: 
    425            if id(elem) not in cloned: 
    426                cloned[id(elem)] = newelem = elem._clone() 
    427                newelem._copy_internals(clone=clone, **kw) 
    428                meth = visitors.get(newelem.__visit_name__, None) 
    429                if meth: 
    430                    meth(newelem) 
    431            return cloned[id(elem)] 
    432 
    433    if obj is not None: 
    434        obj = clone(obj) 
    435    clone = None  # remove gc cycles 
    436    return obj 
    437 
    438 
    439def replacement_traverse(obj, opts, replace): 
    440    """Clone the given expression structure, allowing element 
    441    replacement by a given replacement function. 
    442 
    443    This function is very similar to the :func:`.visitors.cloned_traverse` 
    444    function, except instead of being passed a dictionary of visitors, all 
    445    elements are unconditionally passed into the given replace function. 
    446    The replace function then has the option to return an entirely new object 
    447    which will replace the one given.  If it returns ``None``, then the object 
    448    is kept in place. 
    449 
    450    The difference in usage between :func:`.visitors.cloned_traverse` and 
    451    :func:`.visitors.replacement_traverse` is that in the former case, an 
    452    already-cloned object is passed to the visitor function, and the visitor 
    453    function can then manipulate the internal state of the object. 
    454    In the case of the latter, the visitor function should only return an 
    455    entirely different object, or do nothing. 
    456 
    457    The use case for :func:`.visitors.replacement_traverse` is that of 
    458    replacing a FROM clause inside of a SQL structure with a different one, 
    459    as is a common use case within the ORM. 
    460 
    461    """ 
    462 
    463    cloned = {} 
    464    stop_on = {id(x) for x in opts.get("stop_on", [])} 
    465 
    466    def clone(elem, **kw): 
    467        if ( 
    468            id(elem) in stop_on 
    469            or "no_replacement_traverse" in elem._annotations 
    470        ): 
    471            return elem 
    472        else: 
    473            newelem = replace(elem) 
    474            if newelem is not None: 
    475                stop_on.add(id(newelem)) 
    476                return newelem 
    477            else: 
    478                if elem not in cloned: 
    479                    cloned[elem] = newelem = elem._clone() 
    480                    newelem._copy_internals(clone=clone, **kw) 
    481                return cloned[elem] 
    482 
    483    if obj is not None: 
    484        obj = clone(obj, **opts) 
    485    clone = None  # remove gc cycles 
    486    return obj