1# sql/visitors.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""Visitor/traversal interface and library functions.
9
10SQLAlchemy schema and expression constructs rely on a Python-centric
11version of the classic "visitor" pattern as the primary way in which
12they apply functionality. The most common use of this pattern
13is statement compilation, where individual expression classes match
14up to rendering methods that produce a string result. Beyond this,
15the visitor system is also used to inspect expressions for various
16information and patterns, as well as for the purposes of applying
17transformations to expressions.
18
Examples of how the visit system is used can be seen in the source code
of, for example, the ``sqlalchemy.sql.util`` and ``sqlalchemy.sql.compiler``
modules. Some background on clause adaption is also at
http://techspot.zzzeek.org/2008/01/23/expression-transformations/ .
23
24"""
25
26from collections import deque
27import operator
28
29from .. import exc
30from .. import util
31
32
# Names exported by ``from sqlalchemy.sql.visitors import *``: the
# visitable base types, the visitor base classes, and the module-level
# iteration / traversal functions.
__all__ = [
    "VisitableType",
    "Visitable",
    "ClauseVisitor",
    "CloningVisitor",
    "ReplacingCloningVisitor",
    "iterate",
    "iterate_depthfirst",
    "traverse_using",
    "traverse",
    "traverse_depthfirst",
    "cloned_traverse",
    "replacement_traverse",
]
47
48
class VisitableType(type):
    """Metaclass which installs a ``_compiler_dispatch`` method on classes
    that expose a ``__visit_name__`` attribute.

    The installed ``_compiler_dispatch`` is an instance method which behaves
    approximately like the following::

        def _compiler_dispatch(self, visitor, **kw):
            '''Look for an attribute named "visit_" + self.__visit_name__
            on the visitor, and call it with the same kw params.'''
            visit_attr = 'visit_%s' % self.__visit_name__
            return getattr(visitor, visit_attr)(self, **kw)

    Classes without a ``__visit_name__`` attribute, as well as the class
    named ``Visitable`` itself, are left untouched.

    """

    def __init__(cls, clsname, bases, clsdict):
        # the root "Visitable" class never receives a dispatcher of its own
        is_root = clsname == "Visitable"
        if not is_root and hasattr(cls, "__visit_name__"):
            _generate_dispatch(cls)

        super(VisitableType, cls).__init__(clsname, bases, clsdict)
71
72
73def _generate_dispatch(cls):
74 """Return an optimized visit dispatch function for the cls
75 for use by the compiler.
76
77 """
78 if "__visit_name__" in cls.__dict__:
79 visit_name = cls.__visit_name__
80
81 if isinstance(visit_name, util.compat.string_types):
82 # There is an optimization opportunity here because the
83 # the string name of the class's __visit_name__ is known at
84 # this early stage (import time) so it can be pre-constructed.
85 getter = operator.attrgetter("visit_%s" % visit_name)
86
87 def _compiler_dispatch(self, visitor, **kw):
88 try:
89 meth = getter(visitor)
90 except AttributeError as err:
91 util.raise_(
92 exc.UnsupportedCompilationError(visitor, cls),
93 replace_context=err,
94 )
95 else:
96 return meth(self, **kw)
97
98 else:
99 # The optimization opportunity is lost for this case because the
100 # __visit_name__ is not yet a string. As a result, the visit
101 # string has to be recalculated with each compilation.
102 def _compiler_dispatch(self, visitor, **kw):
103 visit_attr = "visit_%s" % self.__visit_name__
104 try:
105 meth = getattr(visitor, visit_attr)
106 except AttributeError as err:
107 util.raise_(
108 exc.UnsupportedCompilationError(visitor, cls),
109 replace_context=err,
110 )
111 else:
112 return meth(self, **kw)
113
114 _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__
115 on the visitor, and call it with the same kw params.
116 """
117 cls._compiler_dispatch = _compiler_dispatch
118
119
class Visitable(util.with_metaclass(VisitableType, object)):
    """Base class for visitable objects, applies the
    :class:`.visitors.VisitableType` metaclass.

    The :class:`.Visitable` class is essentially at the base of the
    :class:`_expression.ClauseElement` hierarchy.

    The metaclass does not generate a ``_compiler_dispatch`` method for
    this class itself; one is generated for subclasses that declare a
    ``__visit_name__`` attribute.

    """
128
129
class ClauseVisitor(object):
    """Base class for visitor objects that are driven by the
    :func:`.visitors.traverse` function.

    Calling :func:`.visitors.traverse` directly is usually preferred over
    subclassing this class.

    """

    # keyword options handed to ``get_children()`` during iteration
    __traverse_options__ = {}

    def traverse_single(self, obj, **kw):
        """Dispatch ``obj`` to the first visitor in the chain that has a
        truthy ``visit_<__visit_name__>`` attribute and return its result.

        Returns ``None`` when no visitor in the chain handles the object.

        """
        for visitor in self.visitor_iterator:
            handler = getattr(
                visitor, "visit_%s" % obj.__visit_name__, None
            )
            if handler:
                return handler(obj, **kw)

    def iterate(self, obj):
        """Return an iteration of all elements of the given expression
        structure, using this visitor's traverse options.

        """
        options = self.__traverse_options__
        return iterate(obj, options)

    def traverse(self, obj):
        """Traverse and visit the given expression structure."""

        options = self.__traverse_options__
        return traverse(obj, options, self._visitor_dict)

    @util.memoized_property
    def _visitor_dict(self):
        # map "<name>" -> bound ``visit_<name>`` attribute of this visitor
        prefix = "visit_"
        return {
            name[6:]: getattr(self, name)
            for name in dir(self)
            if name.startswith(prefix)
        }

    @property
    def visitor_iterator(self):
        """Iterate through this visitor and each 'chained' visitor."""

        current = self
        while current:
            yield current
            # ``_next`` is only present once chain() has been called
            current = getattr(current, "_next", None)

    def chain(self, visitor):
        """'Chain' an additional ClauseVisitor onto this ClauseVisitor.

        The new visitor is attached after the current last visitor of the
        chain, so it receives all visit events after the existing ones.
        Returns ``self``.

        """
        links = list(self.visitor_iterator)
        links[-1]._next = visitor
        return self
186
187
class CloningVisitor(ClauseVisitor):
    """Base class for visitor objects that are driven by the
    :func:`.visitors.cloned_traverse` function.

    Calling :func:`.visitors.cloned_traverse` directly is usually
    preferred over subclassing this class.

    """

    def copy_and_process(self, list_):
        """Run a cloned traversal over every element of ``list_`` and
        return the results as a new list, in the same order.

        """
        return [self.traverse(element) for element in list_]

    def traverse(self, obj):
        """Traverse and visit the given expression structure."""

        options = self.__traverse_options__
        handlers = self._visitor_dict
        return cloned_traverse(obj, options, handlers)
211
212
class ReplacingCloningVisitor(CloningVisitor):
    """Base class for visitor objects that are driven by the
    :func:`.visitors.replacement_traverse` function.

    Calling :func:`.visitors.replacement_traverse` directly is usually
    preferred over subclassing this class.

    """

    def replace(self, elem):
        """Receive pre-copied elements during a cloning traversal.

        Subclasses may return a new element, which is then used in place
        of a simple copy of ``elem``; traversal will halt on the returned
        element if it is re-encountered. This base implementation returns
        ``None``, i.e. no replacement.
        """
        return None

    def traverse(self, obj):
        """Traverse and visit the given expression structure."""

        def first_replacement(elem):
            # ask each chained visitor in turn; the first non-None
            # answer wins
            for visitor in self.visitor_iterator:
                substitute = visitor.replace(elem)
                if substitute is not None:
                    return substitute
            return None

        options = self.__traverse_options__
        return replacement_traverse(obj, options, first_replacement)
241
242
def iterate(obj, opts):
    r"""Traverse the given expression structure, returning an iterator.

    Traversal is configured to be breadth-first.

    The central API feature used by the :func:`.visitors.iterate` and
    :func:`.visitors.iterate_depthfirst` functions is the
    :meth:`_expression.ClauseElement.get_children` method of
    :class:`_expression.ClauseElement` objects. This method should return all
    the :class:`_expression.ClauseElement` objects which are associated with a
    particular :class:`_expression.ClauseElement` object. For example, a
    :class:`.Case` structure will refer to a series of
    :class:`_expression.ColumnElement` objects within its "whens" and "else\_"
    member variables.

    :param obj: :class:`_expression.ClauseElement` structure to be traversed

    :param opts: dictionary of iteration options, passed as keyword
     arguments to each ``get_children()`` call. This dictionary is usually
     empty in modern usage.

    :return: an iterator over ``obj`` followed by its descendants in
     breadth-first order. An iterator is returned in all cases, including
     when ``obj`` has no children.

    """
    children = obj.get_children(**opts)
    if not children:
        # fasttrack for atomic elements like columns; still hand back an
        # iterator so the return type does not depend on the input
        return iter((obj,))

    # the root's children were already fetched above, so seed the queue
    # with them rather than calling get_children() on the root again
    traversal = deque([obj])
    queue = deque(children)
    while queue:
        element = queue.popleft()
        traversal.append(element)
        queue.extend(element.get_children(**opts))
    return iter(traversal)
277
278
def iterate_depthfirst(obj, opts):
    """Traverse the given expression structure, returning an iterator.

    Traversal is configured to be depth-first; every element is produced
    after all of its descendants, so ``obj`` itself comes last.

    :param obj: :class:`_expression.ClauseElement` structure to be traversed

    :param opts: dictionary of iteration options, passed as keyword
     arguments to each ``get_children()`` call. This dictionary is usually
     empty in modern usage.

    :return: an iterator over the elements. An iterator is returned in all
     cases, including when ``obj`` has no children.

    .. seealso::

        :func:`.visitors.iterate` - includes a general overview of iteration.

    """
    children = obj.get_children(**opts)
    if not children:
        # fasttrack for atomic elements like columns; still hand back an
        # iterator so the return type does not depend on the input
        return iter((obj,))

    # the root's children were already fetched above, so seed the stack
    # with them rather than calling get_children() on the root again
    traversal = deque([obj])
    stack = list(children)
    while stack:
        element = stack.pop()
        # prepend, so that parents end up after their descendants
        traversal.appendleft(element)
        stack.extend(element.get_children(**opts))
    return iter(traversal)
307
308
def traverse_using(iterator, obj, visitors):
    """Apply visit functions to every element produced by ``iterator``.

    :func:`.visitors.traverse_using` is usually called internally as the
    result of the :func:`.visitors.traverse` or
    :func:`.visitors.traverse_depthfirst` functions.

    :param iterator: an iterable or sequence which will yield
     :class:`_expression.ClauseElement` structures; the iterator is assumed
     to be the product of the :func:`.visitors.iterate` or
     :func:`.visitors.iterate_depthfirst` functions.

    :param obj: the :class:`_expression.ClauseElement` that was used as the
     target of the :func:`.iterate` or :func:`.iterate_depthfirst` function.
     It is returned unchanged.

    :param visitors: dictionary of visit functions, keyed by
     ``__visit_name__``. See :func:`.traverse` for details on this
     dictionary.

    .. seealso::

        :func:`.traverse`

        :func:`.traverse_depthfirst`

    """
    for element in iterator:
        handler = visitors.get(element.__visit_name__, None)
        if not handler:
            # no (or a falsy) visit function registered for this kind
            continue
        handler(element)
    return obj
342
343
def traverse(obj, opts, visitors):
    """Traverse and visit the given expression structure using the default
    iterator.

    e.g.::

        from sqlalchemy.sql import visitors

        stmt = select([some_table]).where(some_table.c.foo == 'bar')

        def visit_bindparam(bind_param):
            print("found bound value: %s" % bind_param.value)

        visitors.traverse(stmt, {}, {"bindparam": visit_bindparam})

    The objects are produced by the :func:`.visitors.iterate` function,
    which performs a breadth-first traversal.

    :param obj: :class:`_expression.ClauseElement` structure to be traversed

    :param opts: dictionary of iteration options. This dictionary is usually
     empty in modern usage.

    :param visitors: dictionary of visit functions. The dictionary should
     have strings as keys, each of which would correspond to the
     ``__visit_name__`` of a particular kind of SQL expression object, and
     callable functions as values, each of which represents a visitor function
     for that kind of object.

    """
    elements = iterate(obj, opts)
    return traverse_using(elements, obj, visitors)
375
376
def traverse_depthfirst(obj, opts, visitors):
    """Traverse and visit the given expression structure using the
    depth-first iterator.

    The objects are produced by the :func:`.visitors.iterate_depthfirst`
    function, which performs a depth-first traversal using a stack.

    Usage is the same as that of the :func:`.visitors.traverse` function.

    """
    elements = iterate_depthfirst(obj, opts)
    return traverse_using(elements, obj, visitors)
389
390
def cloned_traverse(obj, opts, visitors):
    """Clone the given expression structure, allowing modifications by
    visitors.

    Traversal usage is the same as that of :func:`.visitors.traverse`.
    Each visit function in the ``visitors`` dictionary receives the freshly
    cloned element, after that element's internals have been copied, and
    may modify it.

    The central API feature used by the :func:`.visitors.cloned_traverse`
    and :func:`.visitors.replacement_traverse` functions, in addition to the
    :meth:`_expression.ClauseElement.get_children` function that is used to
    achieve the iteration, is the
    :meth:`_expression.ClauseElement._copy_internals` method. For a
    :class:`_expression.ClauseElement` structure to support cloning and
    replacement traversals correctly, it needs to be able to pass a cloning
    function into its internal members in order to make copies of them.

    Elements listed under the ``"stop_on"`` key of ``opts`` are returned
    as-is rather than cloned. An element encountered more than once is
    cloned only once. ``None`` is returned unchanged.

    .. seealso::

        :func:`.visitors.traverse`

        :func:`.visitors.replacement_traverse`

    """

    copies = {}
    halt = set(opts.get("stop_on", []))

    def clone(elem, **kw):
        if elem in halt:
            return elem

        key = id(elem)
        if key not in copies:
            # register the copy before descending, so that a
            # re-encountered element resolves to the same copy
            copies[key] = fresh = elem._clone()
            fresh._copy_internals(clone=clone, **kw)
            visit = visitors.get(fresh.__visit_name__, None)
            if visit:
                visit(fresh)
        return copies[key]

    result = obj if obj is None else clone(obj)
    clone = None  # remove gc cycles
    return result
437
438
def replacement_traverse(obj, opts, replace):
    """Clone the given expression structure, allowing element
    replacement by a given replacement function.

    This function is very similar to the :func:`.visitors.cloned_traverse`
    function, except instead of being passed a dictionary of visitors, all
    elements are unconditionally passed into the given replace function.
    The replace function then has the option to return an entirely new object
    which will replace the one given. If it returns ``None``, then the
    element is cloned and its internals are traversed in turn.

    The difference in usage between :func:`.visitors.cloned_traverse` and
    :func:`.visitors.replacement_traverse` is that in the former case, an
    already-cloned object is passed to the visitor function, and the visitor
    function can then manipulate the internal state of the object.
    In the case of the latter, the visitor function should only return an
    entirely different object, or do nothing.

    The use case for :func:`.visitors.replacement_traverse` is that of
    replacing a FROM clause inside of a SQL structure with a different one,
    as is a common use case within the ORM.

    :param obj: structure to be traversed; ``None`` is returned unchanged.

    :param opts: dictionary of options. Elements listed under ``"stop_on"``
     are returned as-is. The whole dictionary is also passed along as
     keyword arguments to each ``_copy_internals()`` call.

    :param replace: callable receiving each element; returns a replacement
     element or ``None``.

    :return: the cloned structure, or the replacement for ``obj`` itself
     if ``replace`` returned one.

    """

    cloned = {}
    stop_on = {id(x) for x in opts.get("stop_on", [])}

    def clone(elem, **kw):
        if (
            id(elem) in stop_on
            or "no_replacement_traverse" in elem._annotations
        ):
            return elem

        newelem = replace(elem)
        if newelem is not None:
            # never descend into (or re-replace) a replacement element
            stop_on.add(id(newelem))
            return newelem

        # Base "already seen" on id(), not on hash/equality: distinct
        # elements that happen to hash and compare equal must each get
        # their own clone rather than being collapsed onto whichever one
        # was encountered first.
        id_elem = id(elem)
        if id_elem not in cloned:
            cloned[id_elem] = newelem = elem._clone()
            newelem._copy_internals(clone=clone, **kw)
        return cloned[id_elem]

    if obj is not None:
        obj = clone(obj, **opts)
    clone = None  # remove gc cycles
    return obj