1# sql/traversals.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7from collections import deque
8from collections import namedtuple
9import itertools
10import operator
11
12from . import operators
13from .visitors import ExtendedInternalTraversal
14from .visitors import InternalTraversal
15from .. import util
16from ..inspection import inspect
17from ..util import collections_abc
18from ..util import HasMemoized
19from ..util import py37
20
# Returned by a ``compare_<visit_name>`` hook of the comparator strategy to
# indicate that the per-attribute traversal for the current pair of elements
# should be skipped.
SKIP_TRAVERSE = util.symbol("skip_traverse")

# Result markers for comparison hooks.  The comparator strategy tests
# COMPARE_FAILED by identity (``is COMPARE_FAILED``) to abort a comparison.
COMPARE_FAILED = False
COMPARE_SUCCEEDED = True

# Placed into the anon_map (``anon_map[NO_CACHE] = True``) and/or returned by
# dispatcher generation to indicate that no cache key can be produced.
NO_CACHE = util.symbol("no_cache")

# The symbols below are assigned to ``visit_*`` names on the _CacheKey
# visitor in place of real methods; HasCacheKey._gen_cache_key() recognizes
# them by identity and handles the corresponding attribute inline.
CACHE_IN_PLACE = util.symbol("cache_in_place")
CALL_GEN_CACHE_KEY = util.symbol("call_gen_cache_key")
STATIC_CACHE_KEY = util.symbol("static_cache_key")
PROPAGATE_ATTRS = util.symbol("propagate_attrs")
ANON_NAME = util.symbol("anon_name")
30
31
def compare(obj1, obj2, **kw):
    """Compare two elements using a comparator strategy.

    When the ``use_proxies`` keyword is present and truthy, a
    ``ColIdentityComparatorStrategy`` performs the comparison; otherwise a
    ``TraversalComparatorStrategy`` is used.  All keyword arguments,
    including ``use_proxies`` itself, are forwarded unchanged to the
    strategy's ``compare()`` method, whose result is returned.

    """
    strategy_cls = (
        ColIdentityComparatorStrategy
        if kw.get("use_proxies", False)
        else TraversalComparatorStrategy
    )
    return strategy_cls().compare(obj1, obj2, **kw)
39
40
def _preconfigure_traversals(target_hierarchy):
    """Generate traversal dispatchers up front for a class hierarchy.

    Every class found by ``util.walk_subclasses(target_hierarchy)`` that has
    a ``_traverse_internals`` attribute gets its cache key attributes
    generated, followed by the "copy internals" and "get children"
    dispatchers.

    """
    # (visitor, name of the generated attribute), in generation order
    generated = (
        (_copy_internals, "_generated_copy_internals_traversal"),
        (_get_children, "_generated_get_children_traversal"),
    )

    for cls in util.walk_subclasses(target_hierarchy):
        if not hasattr(cls, "_traverse_internals"):
            continue

        cls._generate_cache_attrs()
        for visitor, generated_name in generated:
            visitor.generate_dispatch(
                cls, cls._traverse_internals, generated_name
            )
55
56
class HasCacheKey(object):
    """Mixin for objects which can produce a cache key.

    .. seealso::

        :class:`.CacheKey`

        :ref:`sql_caching`

    """

    _cache_key_traversal = NO_CACHE

    # marker attribute; checked by visitors that filter collections down to
    # cache-key-capable members (e.g. executable options)
    _is_has_cache_key = True

    _hierarchy_supports_caching = True
    """private attribute which may be set to False to prevent the
    inherit_cache warning from being emitted for a hierarchy of subclasses.

    Currently applies to the DDLElement hierarchy which does not implement
    caching.

    """

    inherit_cache = None
    """Indicate if this :class:`.HasCacheKey` instance should make use of the
    cache key generation scheme used by its immediate superclass.

    The attribute defaults to ``None``, which indicates that a construct has
    not yet taken into account whether or not its appropriate for it to
    participate in caching; this is functionally equivalent to setting the
    value to ``False``, except that a warning is also emitted.

    This flag can be set to ``True`` on a particular class, if the SQL that
    corresponds to the object does not change based on attributes which
    are local to this class, and not its superclass.

    .. seealso::

        :ref:`compilerext_caching` - General guideslines for setting the
        :attr:`.HasCacheKey.inherit_cache` attribute for third-party or user
        defined SQL constructs.

    """

    __slots__ = ()

    @classmethod
    def _generate_cache_attrs(cls):
        """Generate the cache key dispatcher for a new class.

        This sets the ``_generated_cache_key_traversal`` attribute once
        called, so it should only be called once per class.

        Only an ``inherit_cache`` value declared directly on ``cls`` (i.e.
        present in ``cls.__dict__``) is considered.  When it is truthy, the
        traversal may come from anywhere in the class hierarchy; otherwise
        only a ``_cache_key_traversal`` / ``_traverse_internals`` declared
        directly on ``cls`` is used.

        :return: the dispatcher returned by
         ``_cache_key_traversal_visitor.generate_dispatch()``, or the
         ``NO_CACHE`` symbol if no traversal is available for this class.

        """
        inherit_cache = cls.__dict__.get("inherit_cache", None)
        inherit = bool(inherit_cache)

        if inherit:
            # inherited lookup: superclass attributes are consulted
            _cache_key_traversal = getattr(cls, "_cache_key_traversal", None)
            if _cache_key_traversal is None:
                try:
                    _cache_key_traversal = cls._traverse_internals
                except AttributeError:
                    cls._generated_cache_key_traversal = NO_CACHE
                    return NO_CACHE

            # TODO: wouldn't we instead get this from our superclass?
            # also, our superclass may not have this yet, but in any case,
            # we'd generate for the superclass that has it. this is a little
            # more complicated, so for the moment this is a little less
            # efficient on startup but simpler.
            return _cache_key_traversal_visitor.generate_dispatch(
                cls, _cache_key_traversal, "_generated_cache_key_traversal"
            )
        else:
            # non-inherited lookup: only attributes declared directly on
            # this class count
            _cache_key_traversal = cls.__dict__.get(
                "_cache_key_traversal", None
            )
            if _cache_key_traversal is None:
                _cache_key_traversal = cls.__dict__.get(
                    "_traverse_internals", None
                )
                if _cache_key_traversal is None:
                    cls._generated_cache_key_traversal = NO_CACHE
                    # warn only when inherit_cache was never stated on the
                    # class; an explicit False silences the warning
                    if (
                        inherit_cache is None
                        and cls._hierarchy_supports_caching
                    ):
                        util.warn(
                            "Class %s will not make use of SQL compilation "
                            "caching as it does not set the 'inherit_cache' "
                            "attribute to ``True``. This can have "
                            "significant performance implications including "
                            "some performance degradations in comparison to "
                            "prior SQLAlchemy versions. Set this attribute "
                            "to True if this object can make use of the cache "
                            "key generated by the superclass. Alternatively, "
                            "this attribute may be set to False which will "
                            "disable this warning." % (cls.__name__),
                            code="cprf",
                        )
                    return NO_CACHE

            return _cache_key_traversal_visitor.generate_dispatch(
                cls, _cache_key_traversal, "_generated_cache_key_traversal"
            )

    @util.preload_module("sqlalchemy.sql.elements")
    def _gen_cache_key(self, anon_map, bindparams):
        """Return an optional cache key.

        The cache key is a tuple which can contain any series of
        objects that are hashable and also identifies
        this object uniquely within the presence of a larger SQL expression
        or statement, for the purposes of caching the resulting query.

        The cache key should be based on the SQL compiled structure that would
        ultimately be produced. That is, two structures that are composed in
        exactly the same way should produce the same cache key; any difference
        in the structures that would affect the SQL string or the type handlers
        should result in a different cache key.

        If a structure cannot produce a useful cache key, the NO_CACHE
        symbol should be added to the anon_map and the method should
        return None.

        :param anon_map: mapping with an ``index`` counter attribute; used
         to assign a positional identifier to each object encountered, keyed
         on ``id(obj)``.
        :param bindparams: collection passed through to nested
         ``_gen_cache_key()`` calls and visitor methods.
        :return: a tuple beginning with ``(identifier, class)``, or ``None``
         when ``NO_CACHE`` was recorded in ``anon_map``.

        """

        idself = id(self)
        cls = self.__class__

        if idself in anon_map:
            # object already seen in this key; refer to it by the
            # identifier assigned the first time
            return (anon_map[idself], cls)
        else:
            # inline of
            # id_ = anon_map[idself]
            anon_map[idself] = id_ = str(anon_map.index)
            anon_map.index += 1

        try:
            # look in the class's own __dict__ only, so that a dispatcher
            # generated for a superclass is not picked up
            dispatcher = cls.__dict__["_generated_cache_key_traversal"]
        except KeyError:
            # most of the dispatchers are generated up front
            # in sqlalchemy/sql/__init__.py ->
            # traversals.py-> _preconfigure_traversals().
            # this block will generate any remaining dispatchers.
            dispatcher = cls._generate_cache_attrs()

        if dispatcher is NO_CACHE:
            anon_map[NO_CACHE] = True
            return None

        result = (id_, cls)

        # inline of _cache_key_traversal_visitor.run_generated_dispatch()

        for attrname, obj, meth in dispatcher(
            self, _cache_key_traversal_visitor
        ):
            if obj is not None:
                # TODO: see if C code can help here as Python lacks an
                # efficient switch construct

                # "meth" is either one of the marker symbols handled inline
                # below, or a real visitor method called in the final "else"
                if meth is STATIC_CACHE_KEY:
                    sck = obj._static_cache_key
                    if sck is NO_CACHE:
                        anon_map[NO_CACHE] = True
                        return None
                    result += (attrname, sck)
                elif meth is ANON_NAME:
                    elements = util.preloaded.sql_elements
                    if isinstance(obj, elements._anonymous_label):
                        obj = obj.apply_map(anon_map)
                    result += (attrname, obj)
                elif meth is CALL_GEN_CACHE_KEY:
                    result += (
                        attrname,
                        obj._gen_cache_key(anon_map, bindparams),
                    )

                # remaining cache functions are against
                # Python tuples, dicts, lists, etc. so we can skip
                # if they are empty
                elif obj:
                    if meth is CACHE_IN_PLACE:
                        result += (attrname, obj)
                    elif meth is PROPAGATE_ATTRS:
                        result += (
                            attrname,
                            obj["compile_state_plugin"],
                            obj["plugin_subject"]._gen_cache_key(
                                anon_map, bindparams
                            )
                            if obj["plugin_subject"]
                            else None,
                        )
                    elif meth is InternalTraversal.dp_annotations_key:
                        # obj here is the _annotations dict.  Table uses
                        # a memoized version of it.  however in other cases,
                        # we generate it given anon_map as we may be from a
                        # Join, Aliased, etc.
                        # see #8790

                        if self._gen_static_annotations_cache_key:  # type: ignore  # noqa: E501
                            result += self._annotations_cache_key  # type: ignore  # noqa: E501
                        else:
                            result += self._gen_annotations_cache_key(anon_map)  # type: ignore  # noqa: E501
                    elif (
                        meth is InternalTraversal.dp_clauseelement_list
                        or meth is InternalTraversal.dp_clauseelement_tuple
                        or meth
                        is InternalTraversal.dp_memoized_select_entities
                    ):
                        result += (
                            attrname,
                            tuple(
                                [
                                    elem._gen_cache_key(anon_map, bindparams)
                                    for elem in obj
                                ]
                            ),
                        )
                    else:
                        # a real visitor method; it returns the tuple
                        # fragment to append to the key
                        result += meth(
                            attrname, obj, self, anon_map, bindparams
                        )
        return result

    def _generate_cache_key(self):
        """Return a cache key.

        The cache key is a tuple which can contain any series of
        objects that are hashable and also identifies
        this object uniquely within the presence of a larger SQL expression
        or statement, for the purposes of caching the resulting query.

        The cache key should be based on the SQL compiled structure that would
        ultimately be produced. That is, two structures that are composed in
        exactly the same way should produce the same cache key; any difference
        in the structures that would affect the SQL string or the type handlers
        should result in a different cache key.

        The cache key returned by this method is an instance of
        :class:`.CacheKey`, which consists of a tuple representing the
        cache key, as well as a list of :class:`.BindParameter` objects
        which are extracted from the expression. While two expressions
        that produce identical cache key tuples will themselves generate
        identical SQL strings, the list of :class:`.BindParameter` objects
        indicates the bound values which may have different values in
        each one; these bound parameters must be consulted in order to
        execute the statement with the correct parameters.

        A :class:`_expression.ClauseElement` structure that does not implement
        a :meth:`._gen_cache_key` method and does not implement a
        :attr:`.traverse_internals` attribute will not be cacheable; when
        such an element is embedded into a larger structure, this method
        will return None, indicating no cache key is available.

        """

        bindparams = []

        _anon_map = anon_map()
        key = self._gen_cache_key(_anon_map, bindparams)
        # any element along the way may have flagged the structure as
        # uncacheable by placing NO_CACHE in the map
        if NO_CACHE in _anon_map:
            return None
        else:
            return CacheKey(key, bindparams)

    @classmethod
    def _generate_cache_key_for_object(cls, obj):
        """Return a :class:`.CacheKey` for the given object, or ``None``.

        Performs the same steps as :meth:`._generate_cache_key`, but against
        ``obj`` rather than ``self``; ``obj`` only needs to provide a
        ``_gen_cache_key(anon_map, bindparams)`` method.

        """
        bindparams = []

        _anon_map = anon_map()
        key = obj._gen_cache_key(_anon_map, bindparams)
        if NO_CACHE in _anon_map:
            return None
        else:
            return CacheKey(key, bindparams)
337
338
class MemoizedHasCacheKey(HasCacheKey, HasMemoized):
    """A :class:`.HasCacheKey` whose ``_generate_cache_key()`` method is
    wrapped with ``HasMemoized.memoized_instancemethod``.

    """

    # NOTE(review): presumably the decorator caches the result per instance
    # after the first call - confirm against HasMemoized.
    @HasMemoized.memoized_instancemethod
    def _generate_cache_key(self):
        # call the HasCacheKey implementation explicitly rather than via
        # the MRO
        return HasCacheKey._generate_cache_key(self)
343
344
class CacheKey(namedtuple("CacheKey", ["key", "bindparams"])):
    """The key used to identify a SQL statement construct in the
    SQL compilation cache.

    ``key`` is the (nested) cache key tuple; ``bindparams`` is the
    collection of bound parameter objects gathered while the key was
    generated.  Equality considers the ``key`` portion only.

    .. seealso::

        :ref:`sql_caching`

    """

    def __hash__(self):
        """CacheKey itself is not hashable - hash the .key portion"""

        # returning a non-integer makes hash(cache_key) raise TypeError
        return None

    def to_offline_string(self, statement_cache, statement, parameters):
        """Generate an "offline string" form of this :class:`.CacheKey`

        The "offline string" is basically the string SQL for the
        statement plus a repr of the bound parameter values in series.
        Whereas the :class:`.CacheKey` object is dependent on in-memory
        identities in order to work as a cache key, the "offline" version
        is suitable for a cache that will work for other processes as well.

        The given ``statement_cache`` is a dictionary-like object where the
        string form of the statement itself will be cached.  This dictionary
        should be in a longer lived scope in order to reduce the time spent
        stringifying statements.

        :param statement_cache: dictionary-like object keyed on ``self.key``.
        :param statement: object whose ``str()`` is the SQL string.
        :param parameters: mapping of parameter key to value.
        :return: ``repr()`` of a ``(sql string, parameter tuple)`` pair.

        """
        if self.key not in statement_cache:
            statement_cache[self.key] = sql_str = str(statement)
        else:
            sql_str = statement_cache[self.key]

        if not self.bindparams:
            # no extracted parameters; use the given ones in key order
            param_tuple = tuple(parameters[key] for key in sorted(parameters))
        else:
            # extracted-parameter order, falling back to each parameter's
            # own value when no override is given
            param_tuple = tuple(
                parameters.get(bindparam.key, bindparam.value)
                for bindparam in self.bindparams
            )

        return repr((sql_str, param_tuple))

    def __eq__(self, other):
        return bool(self.key == other.key)

    def __ne__(self, other):
        return not (self.key == other.key)

    @classmethod
    def _diff_tuples(cls, left, right):
        """Return the :meth:`._diff` string for two raw key tuples."""
        ck1 = CacheKey(left, [])
        ck2 = CacheKey(right, [])
        return ck1._diff(ck2)

    def _whats_different(self, other):
        """Yield a message for every position where two keys differ.

        Both ``key`` tuples are walked in lock step, depth first.  Where
        the two elements at a position differ and are both tuples, the walk
        descends into them; any other difference is reported as
        ``key[i][j]...: <left> != <right>``.  A position missing from the
        shorter of two tuples is treated as ``None``.

        After a nested tuple has been fully examined the walk resumes in
        its parent at the following position, so all differences are
        reported, and nothing is yielded for equal keys.

        """
        # indexes of the nested tuples currently descended into
        path = []
        # one frame per tuple being examined:
        # [left tuple, right tuple, next index to examine]
        frames = [[self.key, other.key, 0]]

        while frames:
            frame = frames[-1]
            s1, s2, idx = frame
            len1, len2 = len(s1), len(s2)
            limit = max(len1, len2)
            descended = False

            while idx < limit:
                e1 = s1[idx] if idx < len1 else None
                e2 = s2[idx] if idx < len2 else None
                idx += 1

                if e1 == e2:
                    continue

                if isinstance(e1, tuple) and isinstance(e2, tuple):
                    # remember where to resume in this tuple, then
                    # examine the nested pair first
                    frame[2] = idx
                    path.append(idx - 1)
                    frames.append([e1, e2, 0])
                    descended = True
                    break

                yield "key%s[%d]: %s != %s" % (
                    "".join("[%d]" % id_ for id_ in path),
                    idx - 1,
                    e1,
                    e2,
                )

            if not descended:
                # this tuple is exhausted; go back to its parent, if any
                frames.pop()
                if path:
                    path.pop()

    def _diff(self, other):
        """Return all differences from ``other`` as one comma-separated
        string; the string is empty when the keys are equal."""
        return ", ".join(self._whats_different(other))

    def __str__(self):
        stack = [self.key]

        output = []
        # the sentinel marks the end of a nested tuple within the stack
        sentinel = object()
        indent = -1
        while stack:
            elem = stack.pop(0)
            if elem is sentinel:
                output.append((" " * (indent * 2)) + "),")
                indent -= 1
            elif isinstance(elem, tuple):
                if not elem:
                    output.append((" " * ((indent + 1) * 2)) + "()")
                else:
                    indent += 1
                    stack = list(elem) + [sentinel] + stack
                    output.append((" " * (indent * 2)) + "(")
            else:
                if isinstance(elem, HasCacheKey):
                    # short form instead of the element's own repr()
                    repr_ = "<%s object at %s>" % (
                        type(elem).__name__,
                        hex(id(elem)),
                    )
                else:
                    repr_ = repr(elem)
                output.append((" " * (indent * 2)) + " " + repr_ + ", ")

        return "CacheKey(key=%s)" % ("\n".join(output),)

    def _generate_param_dict(self):
        """used for testing"""

        from .compiler import prefix_anon_map

        _anon_map = prefix_anon_map()
        return {b.key % _anon_map: b.effective_value for b in self.bindparams}

    def _apply_params_to_element(self, original_cache_key, target_element):
        """Return ``target_element.params()`` called with this key's
        parameter values.

        The bound parameters of ``original_cache_key`` and of this key are
        paired positionally; each original parameter's ``key`` is mapped to
        the ``value`` of this key's parameter at the same position.

        """
        translate = {
            k.key: v.value
            for k, v in zip(original_cache_key.bindparams, self.bindparams)
        }

        return target_element.params(translate)
483
484def _clone(element, **kw):
485 return element._clone()
486
487
class _CacheKey(ExtendedInternalTraversal):
    """Visitor producing the cache key fragment for each kind of attribute.

    Each real ``visit_*`` method receives ``(attrname, obj, parent,
    anon_map, bindparams)`` and returns the tuple fragment that is appended
    to the cache key being built.

    """

    # very common elements are inlined into the main _get_cache_key() method
    # to produce a dramatic savings in Python function call overhead; for
    # those, the "method" is a marker symbol rather than a callable.

    visit_has_cache_key = CALL_GEN_CACHE_KEY
    visit_clauseelement = CALL_GEN_CACHE_KEY
    visit_clauseelement_list = InternalTraversal.dp_clauseelement_list
    visit_annotations_key = InternalTraversal.dp_annotations_key
    visit_clauseelement_tuple = InternalTraversal.dp_clauseelement_tuple
    visit_memoized_select_entities = (
        InternalTraversal.dp_memoized_select_entities
    )

    visit_string = CACHE_IN_PLACE
    visit_boolean = CACHE_IN_PLACE
    visit_operator = CACHE_IN_PLACE
    visit_plain_obj = CACHE_IN_PLACE
    visit_statement_hint_list = CACHE_IN_PLACE
    visit_type = STATIC_CACHE_KEY
    visit_anon_name = ANON_NAME

    visit_propagate_attrs = PROPAGATE_ATTRS

    def visit_with_context_options(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        # each entry is keyed on the function's code object plus its
        # accompanying cache key
        entries = []
        for fn, c_key in obj:
            entries.append((fn.__code__, c_key))
        return tuple(entries)

    def visit_inspectable(self, attrname, obj, parent, anon_map, bindparams):
        inspected = inspect(obj)
        return (attrname, inspected._gen_cache_key(anon_map, bindparams))

    def visit_string_list(self, attrname, obj, parent, anon_map, bindparams):
        # the strings themselves form the fragment
        return tuple(obj)

    def visit_multi(self, attrname, obj, parent, anon_map, bindparams):
        # cache-key-capable objects contribute their key, anything else
        # is used as is
        if isinstance(obj, HasCacheKey):
            member = obj._gen_cache_key(anon_map, bindparams)
        else:
            member = obj
        return (attrname, member)

    def visit_multi_list(self, attrname, obj, parent, anon_map, bindparams):
        members = []
        for elem in obj:
            if isinstance(elem, HasCacheKey):
                members.append(elem._gen_cache_key(anon_map, bindparams))
            else:
                members.append(elem)
        return (attrname, tuple(members))

    def visit_has_cache_key_tuples(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        if not obj:
            return ()

        outer = []
        for tup_elem in obj:
            inner = [
                elem._gen_cache_key(anon_map, bindparams) for elem in tup_elem
            ]
            outer.append(tuple(inner))
        return (attrname, tuple(outer))

    def visit_has_cache_key_list(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        if not obj:
            return ()

        keys = [elem._gen_cache_key(anon_map, bindparams) for elem in obj]
        return (attrname, tuple(keys))

    def visit_executable_options(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        if not obj:
            return ()

        # options that are not cache-key-capable are left out of the key
        keys = []
        for elem in obj:
            if elem._is_has_cache_key:
                keys.append(elem._gen_cache_key(anon_map, bindparams))
        return (attrname, tuple(keys))

    def visit_inspectable_list(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        inspected = [inspect(o) for o in obj]
        return self.visit_has_cache_key_list(
            attrname, inspected, parent, anon_map, bindparams
        )

    def visit_clauseelement_tuples(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        # same treatment as any sequence of tuples of cache-key objects
        return self.visit_has_cache_key_tuples(
            attrname, obj, parent, anon_map, bindparams
        )

    def visit_fromclause_ordered_set(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        if not obj:
            return ()

        keys = []
        for elem in obj:
            keys.append(elem._gen_cache_key(anon_map, bindparams))
        return (attrname, tuple(keys))

    def visit_clauseelement_unordered_set(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        if not obj:
            return ()

        cache_keys = []
        for elem in obj:
            cache_keys.append(elem._gen_cache_key(anon_map, bindparams))

        # cache keys all start with (id_, class); sorting them makes the
        # fragment independent of set iteration order
        ordered = tuple(sorted(cache_keys))
        return (attrname, ordered)

    def visit_named_ddl_element(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        name = obj.name
        return (attrname, name)

    def visit_prefix_sequence(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        if not obj:
            return ()

        pairs = []
        for clause, strval in obj:
            pairs.append((clause._gen_cache_key(anon_map, bindparams), strval))
        return (attrname, tuple(pairs))

    def visit_setup_join_tuple(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        # for "legacy" attributes, plain strings are used directly as
        # part of the key
        is_legacy = "legacy" in attrname

        entries = []
        for (target, onclause, from_, flags) in obj:
            if is_legacy and isinstance(target, str):
                target_key = target
            else:
                target_key = target._gen_cache_key(anon_map, bindparams)

            if is_legacy and isinstance(onclause, str):
                onclause_key = onclause
            elif onclause is not None:
                onclause_key = onclause._gen_cache_key(anon_map, bindparams)
            else:
                onclause_key = None

            if from_ is not None:
                from_key = from_._gen_cache_key(anon_map, bindparams)
            else:
                from_key = None

            flag_items = tuple([(key, flags[key]) for key in sorted(flags)])

            entries.append((target_key, onclause_key, from_key, flag_items))

        return tuple(entries)

    def visit_table_hint_list(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        if not obj:
            return ()

        hints = []
        for (clause, dialect_name), text in obj.items():
            hints.append(
                (
                    clause._gen_cache_key(anon_map, bindparams),
                    dialect_name,
                    text,
                )
            )
        return (attrname, tuple(hints))

    def visit_plain_dict(self, attrname, obj, parent, anon_map, bindparams):
        items = []
        for key in sorted(obj):
            items.append((key, obj[key]))
        return (attrname, tuple(items))

    def visit_dialect_options(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        per_dialect = []
        for dialect_name in sorted(obj):
            options = tuple(
                [
                    (key, obj[dialect_name][key])
                    for key in sorted(obj[dialect_name])
                ]
            )
            per_dialect.append((dialect_name, options))
        return (attrname, tuple(per_dialect))

    def visit_string_clauseelement_dict(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        items = []
        for key in sorted(obj):
            items.append((key, obj[key]._gen_cache_key(anon_map, bindparams)))
        return (attrname, tuple(items))

    def visit_string_multi_dict(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        sorted_items = [(key, obj[key]) for key in sorted(obj)]

        items = []
        for key, value in sorted_items:
            if isinstance(value, HasCacheKey):
                items.append(
                    (key, value._gen_cache_key(anon_map, bindparams))
                )
            else:
                items.append((key, value))
        return (attrname, tuple(items))

    def visit_fromclause_canonical_column_collection(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        # inlining into the internals of ColumnCollection
        column_keys = []
        for k, col in obj._collection:
            column_keys.append(col._gen_cache_key(anon_map, bindparams))
        return (attrname, tuple(column_keys))

    def visit_unknown_structure(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        # nothing is known about the attribute, so disable caching
        anon_map[NO_CACHE] = True
        return ()

    def visit_dml_ordered_values(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        pairs = []
        for key, value in obj:
            if hasattr(key, "__clause_element__"):
                key_part = key._gen_cache_key(anon_map, bindparams)
            else:
                key_part = key
            pairs.append(
                (key_part, value._gen_cache_key(anon_map, bindparams))
            )
        return (attrname, tuple(pairs))

    def visit_dml_values(self, attrname, obj, parent, anon_map, bindparams):
        if py37:
            # in py37 we can assume two dictionaries created in the same
            # insert ordering will retain that sorting
            pairs = []
            for k in obj:
                if hasattr(k, "__clause_element__"):
                    key_part = k._gen_cache_key(anon_map, bindparams)
                else:
                    key_part = k
                pairs.append(
                    (key_part, obj[k]._gen_cache_key(anon_map, bindparams))
                )
            return (attrname, tuple(pairs))

        expr_values = {k for k in obj if hasattr(k, "__clause_element__")}
        if expr_values:
            # expr values can't be sorted deterministically right now,
            # so no cache
            anon_map[NO_CACHE] = True
            return ()

        str_values = expr_values.symmetric_difference(obj)

        pairs = []
        for k in sorted(str_values):
            pairs.append((k, obj[k]._gen_cache_key(anon_map, bindparams)))
        return (attrname, tuple(pairs))

    def visit_dml_multi_values(
        self, attrname, obj, parent, anon_map, bindparams
    ):
        # multivalues are simply not cacheable right now
        anon_map[NO_CACHE] = True
        return ()


_cache_key_traversal_visitor = _CacheKey()
808
809
class HasCopyInternals(object):
    """Mixin providing ``_copy_internals()`` driven by the class-level
    ``_traverse_internals`` collection."""

    def _clone(self, **kw):
        """Return a shallow copy; must be provided by subclasses."""
        raise NotImplementedError()

    def _copy_internals(self, omit_attrs=(), **kw):
        """Reassign internal elements to be clones of themselves.

        Called during a copy-and-traverse operation on newly
        shallow-copied elements to create a deep copy.

        The given clone function should be used, which may be applying
        additional transformations to the element (i.e. replacement
        traversal, cloned traversal, annotations).

        :param omit_attrs: collection of attribute names to leave
         untouched.
        :param \**kw: passed along to each copy visitor method.

        """
        missing = object()
        traverse_internals = getattr(self, "_traverse_internals", missing)
        if traverse_internals is missing:
            # user-defined classes may not have a _traverse_internals
            return

        dispatched = _copy_internals.run_generated_dispatch(
            self, traverse_internals, "_generated_copy_internals_traversal"
        )
        for attrname, obj, meth in dispatched:
            if attrname in omit_attrs or obj is None:
                continue

            replacement = meth(attrname, self, obj, **kw)
            # a visitor returning None means "leave the attribute alone"
            if replacement is not None:
                setattr(self, attrname, replacement)
842
843
class _CopyInternals(InternalTraversal):
    """Generate a _copy_internals internal traversal dispatch for classes
    with a _traverse_internals collection.

    Each visitor receives ``(attrname, parent, element)`` and returns the
    copied value for that attribute, built by applying ``clone`` to the
    members of ``element``.

    """

    def visit_clauseelement(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        copied = clone(element, **kw)
        return copied

    def visit_clauseelement_list(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        copied = []
        for clause in element:
            copied.append(clone(clause, **kw))
        return copied

    def visit_clauseelement_tuple(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        copied = []
        for clause in element:
            copied.append(clone(clause, **kw))
        return tuple(copied)

    def visit_executable_options(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        copied = []
        for option in element:
            copied.append(clone(option, **kw))
        return tuple(copied)

    def visit_clauseelement_unordered_set(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        copied = set()
        for clause in element:
            copied.add(clone(clause, **kw))
        return copied

    def visit_clauseelement_tuples(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # list of tuples in, list of tuples out
        copied = []
        for elem in element:
            copied.append(tuple([clone(tup_elem, **kw) for tup_elem in elem]))
        return copied

    def visit_string_clauseelement_dict(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return {key: clone(value, **kw) for key, value in element.items()}

    def visit_setup_join_tuple(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        def clone_or_none(obj):
            # None members of a join entry are preserved as None
            if obj is None:
                return None
            return clone(obj, **kw)

        copied = []
        for (target, onclause, from_, flags) in element:
            copied.append(
                (
                    clone_or_none(target),
                    clone_or_none(onclause),
                    clone_or_none(from_),
                    flags,
                )
            )
        return tuple(copied)

    def visit_memoized_select_entities(self, attrname, parent, element, **kw):
        # copied the same way as a tuple of clause elements
        return self.visit_clauseelement_tuple(attrname, parent, element, **kw)

    def visit_dml_ordered_values(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # sequence of 2-tuples
        copied = []
        for key, value in element:
            if hasattr(key, "__clause_element__"):
                new_key = clone(key, **kw)
            else:
                new_key = key
            copied.append((new_key, clone(value, **kw)))
        return copied

    def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw):
        def clone_key(key):
            # only expression-like keys are cloned
            if hasattr(key, "__clause_element__"):
                return clone(key, **kw)
            return key

        return {
            clone_key(key): clone(value, **kw)
            for key, value in element.items()
        }

    def visit_dml_multi_values(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # sequence of sequences, each sequence contains a list/dict/tuple

        def maybe_clone(value):
            # only expression-like objects are cloned
            if hasattr(value, "__clause_element__"):
                return clone(value, **kw)
            return value

        def copy_entry(elem):
            if isinstance(elem, (list, tuple)):
                return [maybe_clone(value) for value in elem]
            elif isinstance(elem, dict):
                return {
                    maybe_clone(key): maybe_clone(value)
                    for key, value in elem.items()
                }
            else:
                # TODO: use abc classes
                assert False

        copied = []
        for sequence in element:
            copied.append([copy_entry(sub_element) for sub_element in sequence])
        return copied

    def visit_propagate_attrs(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # returned uncopied
        return element


_copy_internals = _CopyInternals()
968
969
970def _flatten_clauseelement(element):
971 while hasattr(element, "__clause_element__") and not getattr(
972 element, "is_clause_element", False
973 ):
974 element = element.__clause_element__()
975
976 return element
977
978
class _GetChildren(InternalTraversal):
    """Generate a _children_traversal internal traversal dispatch for classes
    with a _traverse_internals collection.

    Each visitor receives the attribute value and returns an iterable of
    the child elements it contributes.

    """

    def visit_has_cache_key(self, element, **kw):
        # the GetChildren traversal refers explicitly to ClauseElement
        # structures.  Within these, a plain HasCacheKey is not a
        # ClauseElement, so don't include these.
        return ()

    def visit_clauseelement(self, element, **kw):
        children = (element,)
        return children

    def visit_clauseelement_list(self, element, **kw):
        # the list itself is the iterable of children
        return element

    def visit_clauseelement_tuple(self, element, **kw):
        # the tuple itself is the iterable of children
        return element

    def visit_clauseelement_tuples(self, element, **kw):
        # flatten the sequence of tuples into one stream of children
        flattened = itertools.chain.from_iterable(element)
        return flattened

    def visit_fromclause_canonical_column_collection(self, element, **kw):
        return ()

    def visit_string_clauseelement_dict(self, element, **kw):
        # only the values are children; the keys are strings
        children = element.values()
        return children

    def visit_fromclause_ordered_set(self, element, **kw):
        return element

    def visit_clauseelement_unordered_set(self, element, **kw):
        return element

    def visit_setup_join_tuple(self, element, **kw):
        for (target, onclause, from_, flags) in element:
            if from_ is not None:
                yield from_

            # string targets / onclauses are not elements and are skipped
            target_is_string = isinstance(target, str)
            if not target_is_string:
                yield _flatten_clauseelement(target)

            if onclause is None or isinstance(onclause, str):
                continue
            yield _flatten_clauseelement(onclause)

    def visit_memoized_select_entities(self, element, **kw):
        return self.visit_clauseelement_tuple(element, **kw)

    def visit_dml_ordered_values(self, element, **kw):
        for key, value in element:
            # keys are children only when they are expression-like
            key_is_expression = hasattr(key, "__clause_element__")
            if key_is_expression:
                yield key
            yield value

    def visit_dml_values(self, element, **kw):
        expr_values = {k for k in element if hasattr(k, "__clause_element__")}
        str_values = expr_values.symmetric_difference(element)

        # values for plain (sortable) keys first, in sorted key order ...
        for plain_key in sorted(str_values):
            yield element[plain_key]

        # ... then each expression key followed by its value
        for expr_key in expr_values:
            yield expr_key
            yield element[expr_key]

    def visit_dml_multi_values(self, element, **kw):
        return ()

    def visit_propagate_attrs(self, element, **kw):
        return ()


_get_children = _GetChildren()
1051
1052
@util.preload_module("sqlalchemy.sql.elements")
def _resolve_name_for_compare(element, name, anon_map, **kw):
    """Return ``name`` in the form to be used for comparison.

    An ``_anonymous_label`` is resolved through its ``apply_map()`` method
    using ``anon_map``; any other name is returned unchanged.

    """
    anonymous_label_cls = util.preloaded.sql_elements._anonymous_label
    if not isinstance(name, anonymous_label_cls):
        return name

    return name.apply_map(anon_map)
1059
1060
class anon_map(dict):
    """A map that creates new keys for missing key access.

    Produces an incrementing sequence given a series of unique keys: the
    first missing key is assigned ``"0"``, the next ``"1"``, and so on,
    with the values stored as strings.

    This is similar to the compiler prefix_anon_map class although simpler.

    Inlines the approach taken by :class:`sqlalchemy.util.PopulateDict` which
    is otherwise usually used for this type of operation.

    """

    def __init__(self):
        # the number that the next missing key will receive
        self.index = 0

    def __missing__(self, key):
        # invoked by dict.__getitem__ when the key is absent; the new
        # value is stored so later lookups return the same string
        assigned = str(self.index)
        self[key] = assigned
        self.index += 1
        return assigned
1080
1081
class TraversalComparatorStrategy(InternalTraversal, util.MemoizedSlots):
    """Compare two traversable structures attribute by attribute.

    :meth:`.compare` processes a queue of ``(left, right)`` pairs.  For each
    pair it optionally calls a ``compare_<visit_name>`` hook on this class,
    then walks both elements' ``_traverse_internals`` in lockstep and hands
    every attribute to the handler returned by ``self.dispatch()`` for the
    attribute's visit symbol.

    Conventions used by the handlers in this class:

    * ``visit_*`` handlers return ``COMPARE_FAILED`` (``False``) to abort
      the comparison; any other return value, including ``None``, lets the
      comparison continue.  Handlers that need a nested element compared
      append a ``(left, right)`` pair to ``self.stack``.

    * ``compare_*`` hooks return ``COMPARE_FAILED``, ``SKIP_TRAVERSE`` (the
      pair is considered equal, its attributes are not traversed) or a
      sequence of attribute names the hook has already dealt with.

    """

    __slots__ = "stack", "cache", "anon_map"

    def __init__(self):
        # queue of (left, right) pairs that still need to be compared
        self.stack = deque()
        # (left, right) pairs that were already taken off the queue
        self.cache = set()

    def _memoized_attr_anon_map(self):
        """Return the ``(left, right)`` pair of :class:`.anon_map` objects.

        The handlers use index ``0`` for the left side and index ``1`` for
        the right side.  NOTE(review): presumably evaluated once by
        ``util.MemoizedSlots`` on first access of ``self.anon_map``.

        """
        return (anon_map(), anon_map())

    def compare(self, obj1, obj2, **kw):
        """Return ``True`` if ``obj1`` and ``obj2`` compare as equivalent.

        :param obj1: left element; needs ``__visit_name__`` and
         ``_traverse_internals``.
        :param obj2: right element, same requirements.
        :param \\**kw: passed on to the ``compare_*`` hooks and ``visit_*``
         handlers.  ``compare_annotations`` (default ``False``) is read here
         and controls whether ``_annotations`` entries are compared.

        """
        stack = self.stack
        cache = self.cache

        compare_annotations = kw.get("compare_annotations", False)

        stack.append((obj1, obj2))

        while stack:
            left, right = stack.popleft()

            if left is right:
                continue
            elif left is None or right is None:
                # we know they are different so no match
                return False
            elif (left, right) in cache:
                # this exact pair was already processed
                continue
            cache.add((left, right))

            visit_name = left.__visit_name__
            if visit_name != right.__visit_name__:
                return False

            meth = getattr(self, "compare_%s" % visit_name, None)

            if meth:
                attributes_compared = meth(left, right, **kw)
                if attributes_compared is COMPARE_FAILED:
                    return False
                elif attributes_compared is SKIP_TRAVERSE:
                    continue

                # attributes_compared is returned as a list of attribute
                # names that were "handled" by the comparison method above.
                # remaining attribute names in the _traverse_internals
                # will be compared.
            else:
                attributes_compared = ()

            for (
                (left_attrname, left_visit_sym),
                (right_attrname, right_visit_sym),
            ) in util.zip_longest(
                left._traverse_internals,
                right._traverse_internals,
                fillvalue=(None, None),
            ):
                if not compare_annotations and (
                    (left_attrname == "_annotations")
                    or (right_attrname == "_annotations")
                ):
                    continue

                # a (None, None) filler on either side also lands here,
                # i.e. traversal lists of different length don't match
                if (
                    left_attrname != right_attrname
                    or left_visit_sym is not right_visit_sym
                ):
                    return False
                elif left_attrname in attributes_compared:
                    continue

                dispatch = self.dispatch(left_visit_sym)
                left_child = operator.attrgetter(left_attrname)(left)
                right_child = operator.attrgetter(right_attrname)(right)
                if left_child is None:
                    if right_child is not None:
                        return False
                    else:
                        continue

                comparison = dispatch(
                    left_attrname, left, left_child, right, right_child, **kw
                )
                if comparison is COMPARE_FAILED:
                    return False

        return True

    def compare_inner(self, obj1, obj2, **kw):
        """Compare two elements using a fresh instance of this class.

        The new instance has its own stack, cache and anon maps, so the
        nested comparison does not touch the state of ``self``.

        """
        comparator = self.__class__()
        return comparator.compare(obj1, obj2, **kw)

    def visit_has_cache_key(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two objects by the cache keys they generate."""
        if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
            self.anon_map[1], []
        ):
            return COMPARE_FAILED

    def visit_propagate_attrs(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare the ``plugin_subject`` entries of two dictionaries."""
        return self.compare_inner(
            left.get("plugin_subject", None), right.get("plugin_subject", None)
        )

    def visit_has_cache_key_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two sequences pairwise by generated cache key.

        Sequences of different length do not match.

        """
        for left_elem, right_elem in util.zip_longest(
            left, right, fillvalue=None
        ):
            # None is either the zip_longest filler (length mismatch) or
            # an actual None entry; it only matches another None
            if left_elem is None:
                if right_elem is not None:
                    return COMPARE_FAILED
                else:
                    continue
            elif right_elem is None:
                return COMPARE_FAILED

            if left_elem._gen_cache_key(
                self.anon_map[0], []
            ) != right_elem._gen_cache_key(self.anon_map[1], []):
                return COMPARE_FAILED

    def visit_executable_options(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two sequences of options pairwise.

        An option with a true ``_is_has_cache_key`` is represented by its
        generated cache key, any other option by itself.  Sequences of
        different length do not match.

        """
        for left_opt, right_opt in util.zip_longest(
            left, right, fillvalue=None
        ):
            # see visit_has_cache_key_list() for the None handling
            if left_opt is None:
                if right_opt is not None:
                    return COMPARE_FAILED
                else:
                    continue
            elif right_opt is None:
                return COMPARE_FAILED

            if (
                left_opt._gen_cache_key(self.anon_map[0], [])
                if left_opt._is_has_cache_key
                else left_opt
            ) != (
                right_opt._gen_cache_key(self.anon_map[1], [])
                if right_opt._is_has_cache_key
                else right_opt
            ):
                return COMPARE_FAILED

    def visit_clauseelement(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Queue a pair of elements for comparison."""
        self.stack.append((left, right))

    def visit_fromclause_canonical_column_collection(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Queue the members of two collections pairwise.

        A missing member is paired with ``None``, which :meth:`.compare`
        treats as a mismatch.

        """
        for lcol, rcol in util.zip_longest(left, right, fillvalue=None):
            self.stack.append((lcol, rcol))

    def visit_fromclause_derived_column_collection(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Do nothing; these collections are not compared."""
        pass

    def visit_string_clauseelement_dict(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two dictionaries key by key, in sorted key order.

        Keys must be equal; the corresponding values are queued for
        comparison.

        """
        for lstr, rstr in util.zip_longest(
            sorted(left), sorted(right), fillvalue=None
        ):
            if lstr != rstr:
                return COMPARE_FAILED
            self.stack.append((left[lstr], right[rstr]))

    def visit_clauseelement_tuples(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Queue the members of two sequences of tuples pairwise.

        A different number of tuples fails right away; a different number
        of members within a tuple pairs a member with ``None``.

        """
        for ltup, rtup in util.zip_longest(left, right, fillvalue=None):
            if ltup is None or rtup is None:
                return COMPARE_FAILED

            for l, r in util.zip_longest(ltup, rtup, fillvalue=None):
                self.stack.append((l, r))

    def visit_clauseelement_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Queue the members of two sequences pairwise."""
        for l, r in util.zip_longest(left, right, fillvalue=None):
            self.stack.append((l, r))

    def visit_clauseelement_tuple(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Queue the members of two sequences pairwise."""
        for l, r in util.zip_longest(left, right, fillvalue=None):
            self.stack.append((l, r))

    def _compare_unordered_sequences(self, seq1, seq2, **kw):
        """Return ``True`` if both sequences match regardless of order.

        Every member of ``seq1`` has to be matched, via
        :meth:`.compare_inner`, to a distinct member of ``seq2`` that was
        not matched before, and both sequences need the same length.
        ``None`` only matches ``None``.  Members of ``seq2`` are placed in
        a set, so they must be hashable.

        """
        if seq1 is None:
            return seq2 is None
        elif seq2 is None:
            return False

        if len(seq1) != len(seq2):
            return False

        # candidates on the right side that are not yet matched
        remaining = set(seq2)
        for clause in seq1:
            for other_clause in remaining:
                if self.compare_inner(clause, other_clause, **kw):
                    # safe to mutate; the iteration stops right here
                    remaining.discard(other_clause)
                    break
            else:
                # nothing left that matches this clause
                return False
        return True

    def visit_clauseelement_unordered_set(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two collections without regard to ordering."""
        return self._compare_unordered_sequences(left, right, **kw)

    def visit_fromclause_ordered_set(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Queue the members of two ordered collections pairwise."""
        for l, r in util.zip_longest(left, right, fillvalue=None):
            self.stack.append((l, r))

    def visit_string(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two values with ``==``."""
        return left == right

    def visit_string_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two values with ``==``."""
        return left == right

    def visit_anon_name(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two names after resolving each against its anon map."""
        return _resolve_name_for_compare(
            left_parent, left, self.anon_map[0], **kw
        ) == _resolve_name_for_compare(
            right_parent, right, self.anon_map[1], **kw
        )

    def visit_boolean(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two values with ``==``."""
        return left == right

    def visit_operator(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two values with ``==``."""
        return left == right

    def visit_type(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two types using ``left._compare_type_affinity()``."""
        return left._compare_type_affinity(right)

    def visit_plain_dict(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two values with ``==``."""
        return left == right

    def visit_dialect_options(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two values with ``==``."""
        return left == right

    def visit_annotations_key(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare annotations.

        When both sides are non-empty, the parents'
        ``_annotations_cache_key`` values are compared; otherwise the two
        values are compared with ``==``.

        """
        if left and right:
            return (
                left_parent._annotations_cache_key
                == right_parent._annotations_cache_key
            )
        else:
            return left == right

    def visit_with_context_options(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two sequences of ``(function, key)`` pairs.

        Functions are compared by their ``__code__`` object.

        """
        return tuple((fn.__code__, c_key) for fn, c_key in left) == tuple(
            (fn.__code__, c_key) for fn, c_key in right
        )

    def visit_plain_obj(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two values with ``==``."""
        return left == right

    def visit_named_ddl_element(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two elements by their ``name`` attribute.

        ``None`` only matches ``None``.

        """
        if left is None:
            if right is not None:
                return COMPARE_FAILED
            return COMPARE_SUCCEEDED
        elif right is None:
            return COMPARE_FAILED

        return left.name == right.name

    def visit_prefix_sequence(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two sequences of ``(clause, string)`` pairs.

        The strings must be equal; the clauses are queued for comparison.

        """
        for (l_clause, l_str), (r_clause, r_str) in util.zip_longest(
            left, right, fillvalue=(None, None)
        ):
            if l_str != r_str:
                return COMPARE_FAILED
            else:
                self.stack.append((l_clause, r_clause))

    def visit_setup_join_tuple(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two sequences of ``(target, onclause, from, flags)``.

        The flags must be equal; the three other members are queued for
        comparison.

        """
        # TODO: look at attrname for "legacy_join" and use different structure
        for (
            (l_target, l_onclause, l_from, l_flags),
            (r_target, r_onclause, r_from, r_flags),
        ) in util.zip_longest(left, right, fillvalue=(None, None, None, None)):
            if l_flags != r_flags:
                return COMPARE_FAILED
            self.stack.append((l_target, r_target))
            self.stack.append((l_onclause, r_onclause))
            self.stack.append((l_from, r_from))

    def visit_memoized_select_entities(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Handle the value like :meth:`.visit_clauseelement_tuple`."""
        return self.visit_clauseelement_tuple(
            attrname, left_parent, left, right_parent, right, **kw
        )

    def visit_table_hint_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two dictionaries keyed by ``(table, dialect)``.

        Keys are ordered by ``(table.fullname, dialect)`` and paired up.
        Dialects and hint values must be equal; the tables are queued for
        comparison.

        """
        left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
        right_keys = sorted(
            right, key=lambda elem: (elem[0].fullname, elem[1])
        )
        for (ltable, ldialect), (rtable, rdialect) in util.zip_longest(
            left_keys, right_keys, fillvalue=(None, None)
        ):
            if ldialect != rdialect:
                return COMPARE_FAILED
            elif left[(ltable, ldialect)] != right[(rtable, rdialect)]:
                return COMPARE_FAILED
            else:
                self.stack.append((ltable, rtable))

    def visit_statement_hint_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two values with ``==``."""
        return left == right

    def visit_unknown_structure(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Always raise ``NotImplementedError``."""
        raise NotImplementedError()

    def visit_dml_ordered_values(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare the keys of two sequences of ``(key, value)`` pairs.

        NOTE(review): only the keys are passed to
        :meth:`._compare_dml_values_or_ce`; the values ``lv`` / ``rv`` are
        unpacked but not compared.  Confirm whether that is intended
        before relying on this for value-sensitive comparison.

        """
        # sequence of tuple pairs

        for (lk, lv), (rk, rv) in util.zip_longest(
            left, right, fillvalue=(None, None)
        ):
            if not self._compare_dml_values_or_ce(lk, rk, **kw):
                return COMPARE_FAILED

    def _compare_dml_values_or_ce(self, lv, rv, **kw):
        """Return ``True`` if two DML keys or values match.

        Both sides must agree on whether they have a
        ``__clause_element__`` attribute.  Objects that have it are
        compared with :meth:`.compare_inner`; objects that don't are first
        compared with ``!=`` and then passed to :meth:`.compare_inner` as
        well.

        """
        lvce = hasattr(lv, "__clause_element__")
        rvce = hasattr(rv, "__clause_element__")
        if lvce != rvce:
            return False
        elif lvce and not self.compare_inner(lv, rv, **kw):
            return False
        elif not lvce and lv != rv:
            return False
        elif not self.compare_inner(lv, rv, **kw):
            return False

        return True

    def visit_dml_values(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two collections of DML values.

        Both sides have to be sequences, or both have to be non-sequences
        that are handled as dictionaries, and they need the same length.
        ``None`` on either side fails the comparison.

        """
        if left is None or right is None or len(left) != len(right):
            return COMPARE_FAILED

        if isinstance(left, collections_abc.Sequence):
            for lv, rv in zip(left, right):
                if not self._compare_dml_values_or_ce(lv, rv, **kw):
                    return COMPARE_FAILED
        elif isinstance(right, collections_abc.Sequence):
            # left is not a sequence but right is
            return COMPARE_FAILED
        elif py37:
            # dictionaries guaranteed to support insert ordering in
            # py37 so that we can compare the keys in order.  without
            # this, we can't compare SQL expression keys because we don't
            # know which key is which
            for (lk, lv), (rk, rv) in zip(left.items(), right.items()):
                if not self._compare_dml_values_or_ce(lk, rk, **kw):
                    return COMPARE_FAILED
                if not self._compare_dml_values_or_ce(lv, rv, **kw):
                    return COMPARE_FAILED
        else:
            for lk in left:
                lv = left[lk]

                if lk not in right:
                    return COMPARE_FAILED
                rv = right[lk]

                if not self._compare_dml_values_or_ce(lv, rv, **kw):
                    return COMPARE_FAILED

    def visit_dml_multi_values(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        """Compare two nested sequences of DML value collections.

        Each innermost pair is compared with :meth:`.visit_dml_values`;
        outer sequences of different length fail the comparison.

        """
        for lseq, rseq in util.zip_longest(left, right, fillvalue=None):
            if lseq is None or rseq is None:
                return COMPARE_FAILED

            for ld, rd in util.zip_longest(lseq, rseq, fillvalue=None):
                if (
                    self.visit_dml_values(
                        attrname, left_parent, ld, right_parent, rd, **kw
                    )
                    is COMPARE_FAILED
                ):
                    return COMPARE_FAILED

    def compare_clauselist(self, left, right, **kw):
        """Pre-compare two elements that have ``operator`` and ``clauses``.

        The operators must be the same object.  For an associative
        operator the clauses are compared here without regard to order and
        reported as handled; otherwise only ``operator`` is reported as
        handled and the clauses are left to the regular traversal.

        """
        if left.operator is right.operator:
            if operators.is_associative(left.operator):
                if self._compare_unordered_sequences(
                    left.clauses, right.clauses, **kw
                ):
                    return ["operator", "clauses"]
                else:
                    return COMPARE_FAILED
            else:
                return ["operator"]
        else:
            return COMPARE_FAILED

    def compare_binary(self, left, right, **kw):
        """Pre-compare two elements that have ``operator``/``left``/``right``.

        The operators must be equal.  For a commutative operator the
        operands are compared here in both arrangements and reported as
        handled; otherwise ``left`` and ``right`` are left to the regular
        traversal.

        """
        if left.operator == right.operator:
            if operators.is_commutative(left.operator):
                if (
                    self.compare_inner(left.left, right.left, **kw)
                    and self.compare_inner(left.right, right.right, **kw)
                ) or (
                    self.compare_inner(left.left, right.right, **kw)
                    and self.compare_inner(left.right, right.left, **kw)
                ):
                    return ["operator", "negate", "left", "right"]
                else:
                    return COMPARE_FAILED
            else:
                return ["operator", "negate"]
        else:
            return COMPARE_FAILED

    def compare_bindparam(self, left, right, **kw):
        """Return the attribute names to leave out of the traversal.

        ``compare_values=False`` leaves out ``callable`` and ``value``;
        ``compare_keys=False`` leaves out ``key``.  Both flags default to
        ``True``.

        """
        # kw is this call's own dictionary, pop() doesn't affect the caller
        compare_keys = kw.pop("compare_keys", True)
        compare_values = kw.pop("compare_values", True)

        if compare_values:
            omit = []
        else:
            # this means, "skip these, we already compared"
            omit = ["callable", "value"]

        if not compare_keys:
            omit.append("key")

        return omit
1540
1541
class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
    """Comparison strategy that matches columns, labels and tables by
    identity / lineage instead of traversing their attributes.

    Every hook returns either ``SKIP_TRAVERSE`` (the pair matches, its
    attributes are not traversed) or ``COMPARE_FAILED``.

    """

    def compare_column_element(
        self, left, right, use_proxies=True, equivalents=(), **kw
    ):
        """Compare ColumnElements using proxies and equivalent collections.

        This is a comparison strategy specific to the ORM.

        :param left: element being compared.
        :param right: element to compare against.
        :param use_proxies: when true, ``left`` matches a candidate if
         ``left.shares_lineage(candidate)`` is true.
        :param equivalents: optional mapping of an element to a set of
         elements considered equivalent to it; the entries for ``right``
         are used as additional candidates.
        :return: ``SKIP_TRAVERSE`` if ``left`` matches any candidate,
         else ``COMPARE_FAILED``.

        """

        to_compare = (right,)
        if equivalents and right in equivalents:
            to_compare = equivalents[right].union(to_compare)

        for oth in to_compare:
            if use_proxies and left.shares_lineage(oth):
                return SKIP_TRAVERSE
            elif hash(left) == hash(oth):
                # test against the current candidate, so that members of
                # ``equivalents`` are honored without use_proxies as well
                return SKIP_TRAVERSE

        # no candidate matched
        return COMPARE_FAILED

    def compare_column(self, left, right, **kw):
        """Compare two columns via :meth:`.compare_column_element`."""
        return self.compare_column_element(left, right, **kw)

    def compare_label(self, left, right, **kw):
        """Compare two labels via :meth:`.compare_column_element`."""
        return self.compare_column_element(left, right, **kw)

    def compare_table(self, left, right, **kw):
        """Compare two tables; only the very same object matches."""
        # tables compare on identity, since it's not really feasible to
        # compare them column by column with the above rules
        return SKIP_TRAVERSE if left is right else COMPARE_FAILED