1# sql/traversals.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11from collections import deque
12import collections.abc as collections_abc
13import itertools
14from itertools import zip_longest
15import operator
16import typing
17from typing import Any
18from typing import Callable
19from typing import Deque
20from typing import Dict
21from typing import Iterable
22from typing import Optional
23from typing import Set
24from typing import Tuple
25from typing import Type
26
27from . import operators
28from .cache_key import HasCacheKey
29from .visitors import _TraverseInternalsType
30from .visitors import anon_map
31from .visitors import ExternallyTraversible
32from .visitors import HasTraversalDispatch
33from .visitors import HasTraverseInternals
34from .. import util
35from ..util import langhelpers
36from ..util.typing import Self
37
38
# Sentinel a ``compare_<visit_name>()`` method may return to tell
# TraversalComparatorStrategy.compare() to skip the attribute-by-attribute
# traversal of the current pair of elements.
SKIP_TRAVERSE = util.symbol("skip_traverse")
# Returned by comparison methods to signal a mismatch; compare() tests for
# it with ``is`` and then returns False.
COMPARE_FAILED = False
# Boolean counterpart of COMPARE_FAILED.
COMPARE_SUCCEEDED = True
42
43
def compare(obj1: Any, obj2: Any, **kw: Any) -> bool:
    """Compare ``obj1`` to ``obj2`` using a traversal comparator strategy.

    A :class:`.ColIdentityComparatorStrategy` is used when a truthy
    ``use_proxies`` keyword is present, otherwise a plain
    :class:`.TraversalComparatorStrategy`.  All keyword arguments are
    forwarded unchanged to the strategy's ``compare()`` method, whose
    boolean result is returned.

    """
    strategy_cls: Type[TraversalComparatorStrategy] = (
        ColIdentityComparatorStrategy
        if kw.get("use_proxies", False)
        else TraversalComparatorStrategy
    )
    return strategy_cls().compare(obj1, obj2, **kw)
52
53
def _preconfigure_traversals(target_hierarchy: Type[Any]) -> None:
    """Eagerly generate traversal dispatchers for a class hierarchy.

    Every class yielded by ``util.walk_subclasses(target_hierarchy)`` that
    has both ``_generate_cache_attrs`` and ``_traverse_internals`` gets its
    cache attributes generated, followed by the copy-internals and
    get-children dispatch functions.

    """
    for cls in util.walk_subclasses(target_hierarchy):
        eligible = hasattr(cls, "_generate_cache_attrs") and hasattr(
            cls, "_traverse_internals"
        )
        if not eligible:
            continue

        cls._generate_cache_attrs()

        # same order as always: copy-internals first, then get-children
        for dispatcher, generated_name in (
            (_copy_internals, "_generated_copy_internals_traversal"),
            (_get_children, "_generated_get_children_traversal"),
        ):
            dispatcher.generate_dispatch(
                cls, cls._traverse_internals, generated_name
            )
70
71
class HasShallowCopy(HasTraverseInternals):
    """attribute-wide operations that are useful for classes that use
    __slots__ and therefore can't operate on their attributes in a dictionary.

    The copy / to-dict / from-dict functions are generated as source code
    from the attribute names listed in ``_traverse_internals`` and are cached
    on the concrete class the first time they are needed.

    """

    __slots__ = ()

    if typing.TYPE_CHECKING:

        def _generated_shallow_copy_traversal(self, other: Self) -> None: ...

        def _generated_shallow_from_dict_traversal(
            self, d: Dict[str, Any]
        ) -> None: ...

        def _generated_shallow_to_dict_traversal(self) -> Dict[str, Any]: ...

    @classmethod
    def _generate_shallow_copy(
        cls,
        internal_dispatch: _TraverseInternalsType,
        method_name: str,
    ) -> Callable[[Self, Self], None]:
        """Build a function copying each listed attribute from ``self``
        onto ``other``.

        :param internal_dispatch: sequence of ``(attrname, visitor)`` pairs;
         only the attribute names are used.
        :param method_name: name given to the generated function.
        :return: the generated function.

        """
        code = "\n".join(
            f" other.{attrname} = self.{attrname}"
            for attrname, _ in internal_dispatch
        )
        if not code:
            # an empty attribute list would otherwise produce a ``def``
            # with no body, which fails to compile
            code = " pass"
        meth_text = f"def {method_name}(self, other):\n{code}\n"
        return langhelpers._exec_code_in_env(meth_text, {}, method_name)

    @classmethod
    def _generate_shallow_to_dict(
        cls,
        internal_dispatch: _TraverseInternalsType,
        method_name: str,
    ) -> Callable[[Self], Dict[str, Any]]:
        """Build a function returning a dict of ``attrname -> value`` for
        each listed attribute of ``self``.

        :param internal_dispatch: sequence of ``(attrname, visitor)`` pairs;
         only the attribute names are used.
        :param method_name: name given to the generated function.
        :return: the generated function.

        """
        code = ",\n".join(
            f" '{attrname}': self.{attrname}"
            for attrname, _ in internal_dispatch
        )
        meth_text = f"def {method_name}(self):\n return {{{code}}}\n"
        return langhelpers._exec_code_in_env(meth_text, {}, method_name)

    @classmethod
    def _generate_shallow_from_dict(
        cls,
        internal_dispatch: _TraverseInternalsType,
        method_name: str,
    ) -> Callable[[Self, Dict[str, Any]], None]:
        """Build a function assigning each listed attribute of ``self`` from
        the same-named key of a dictionary ``d``.

        :param internal_dispatch: sequence of ``(attrname, visitor)`` pairs;
         only the attribute names are used.
        :param method_name: name given to the generated function.
        :return: the generated function.

        """
        code = "\n".join(
            f" self.{attrname} = d['{attrname}']"
            for attrname, _ in internal_dispatch
        )
        if not code:
            # an empty attribute list would otherwise produce a ``def``
            # with no body, which fails to compile
            code = " pass"
        meth_text = f"def {method_name}(self, d):\n{code}\n"
        return langhelpers._exec_code_in_env(meth_text, {}, method_name)

    def _shallow_from_dict(self, d: Dict[str, Any]) -> None:
        """Populate this object's traversed attributes from ``d``."""
        cls = self.__class__

        shallow_from_dict: Callable[[HasShallowCopy, Dict[str, Any]], None]
        try:
            # look in the class' own __dict__ so that a function generated
            # for a superclass is not picked up
            shallow_from_dict = cls.__dict__[
                "_generated_shallow_from_dict_traversal"
            ]
        except KeyError:
            shallow_from_dict = self._generate_shallow_from_dict(
                cls._traverse_internals,
                "_generated_shallow_from_dict_traversal",
            )

            cls._generated_shallow_from_dict_traversal = shallow_from_dict  # type: ignore  # noqa: E501

        shallow_from_dict(self, d)

    def _shallow_to_dict(self) -> Dict[str, Any]:
        """Return this object's traversed attributes as a dictionary."""
        cls = self.__class__

        shallow_to_dict: Callable[[HasShallowCopy], Dict[str, Any]]

        try:
            shallow_to_dict = cls.__dict__[
                "_generated_shallow_to_dict_traversal"
            ]
        except KeyError:
            shallow_to_dict = self._generate_shallow_to_dict(
                cls._traverse_internals, "_generated_shallow_to_dict_traversal"
            )

            cls._generated_shallow_to_dict_traversal = shallow_to_dict  # type: ignore  # noqa: E501
        return shallow_to_dict(self)

    def _shallow_copy_to(self, other: Self) -> None:
        """Copy this object's traversed attributes onto ``other``."""
        cls = self.__class__

        shallow_copy: Callable[[Self, Self], None]
        try:
            shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"]
        except KeyError:
            shallow_copy = self._generate_shallow_copy(
                cls._traverse_internals, "_generated_shallow_copy_traversal"
            )

            cls._generated_shallow_copy_traversal = shallow_copy  # type: ignore  # noqa: E501
        shallow_copy(self, other)

    def _clone(self, **kw: Any) -> Self:
        """Create a shallow copy"""
        c = self.__class__.__new__(self.__class__)
        self._shallow_copy_to(c)
        return c
184
185
class GenerativeOnTraversal(HasShallowCopy):
    """Supplies Generative behavior but making use of traversals to shallow
    copy.

    .. seealso::

        :class:`sqlalchemy.sql.base.Generative`


    """

    __slots__ = ()

    def _generate(self) -> Self:
        """Return a new instance of this object's class, populated via
        ``_shallow_copy_to()`` and created without calling ``__init__``."""
        target_cls = self.__class__
        fresh = target_cls.__new__(target_cls)
        self._shallow_copy_to(fresh)
        return fresh
204
205
206def _clone(element, **kw):
207 return element._clone()
208
209
class HasCopyInternals(HasTraverseInternals):
    """Mixin providing a traversal-driven ``_copy_internals()``."""

    __slots__ = ()

    def _clone(self, **kw):
        raise NotImplementedError()

    def _copy_internals(
        self, *, omit_attrs: Iterable[str] = (), **kw: Any
    ) -> None:
        """Replace internal elements with clones of themselves.

        This runs as part of a copy-and-traverse operation, on an element
        that was just shallow-copied, in order to produce a deep copy.

        The clone function passed along in ``kw`` should be used; it may
        apply further transformations to each element (i.e. replacement
        traversal, cloned traversal, annotations).

        :param omit_attrs: attribute names to leave untouched.

        """

        try:
            internals = self._traverse_internals
        except AttributeError:
            # user-defined classes may not have a _traverse_internals
            return

        dispatched = _copy_internals.run_generated_dispatch(
            self, internals, "_generated_copy_internals_traversal"
        )
        for attrname, current, copier in dispatched:
            if attrname in omit_attrs or current is None:
                continue

            replacement = copier(attrname, self, current, **kw)
            # a copier returning None means "leave the attribute as is"
            if replacement is not None:
                setattr(self, attrname, replacement)
246
247
class _CopyInternalsTraversal(HasTraversalDispatch):
    """Generate a _copy_internals internal traversal dispatch for classes
    with a _traverse_internals collection.

    Each ``visit_*`` method receives the attribute name, the parent object
    and the attribute's current value, and returns the cloned replacement
    value built with the given ``clone`` callable.

    """

    def visit_clauseelement(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return clone(element, **kw)

    def visit_clauseelement_list(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return [clone(item, **kw) for item in element]

    def visit_clauseelement_tuple(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        cloned = [clone(item, **kw) for item in element]
        return tuple(cloned)

    def visit_executable_options(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        cloned = [clone(option, **kw) for option in element]
        return tuple(cloned)

    def visit_clauseelement_unordered_set(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return {clone(member, **kw) for member in element}

    def visit_clauseelement_tuples(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # list of tuples in, list of tuples out
        copied = []
        for row in element:
            copied.append(tuple([clone(member, **kw) for member in row]))
        return copied

    def visit_string_clauseelement_dict(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        copied = {}
        for name, clause in element.items():
            copied[name] = clone(clause, **kw)
        return copied

    def visit_setup_join_tuple(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        def clone_or_none(obj):
            return None if obj is None else clone(obj, **kw)

        return tuple(
            [
                (
                    clone_or_none(target),
                    clone_or_none(onclause),
                    clone_or_none(from_),
                    flags,
                )
                for target, onclause, from_, flags in element
            ]
        )

    def visit_memoized_select_entities(self, attrname, parent, element, **kw):
        return self.visit_clauseelement_tuple(attrname, parent, element, **kw)

    def visit_dml_ordered_values(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # sequence of 2-tuples; keys are cloned only when they provide
        # __clause_element__, values are always cloned
        copied = []
        for key, value in element:
            new_key = (
                clone(key, **kw) if hasattr(key, "__clause_element__") else key
            )
            copied.append((new_key, clone(value, **kw)))
        return copied

    def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw):
        def clone_key(key):
            if hasattr(key, "__clause_element__"):
                return clone(key, **kw)
            return key

        return {
            clone_key(key): clone(value, **kw)
            for key, value in element.items()
        }

    def visit_dml_multi_values(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # sequence of sequences, each sequence contains a list/dict/tuple

        def clone_if_expr(obj):
            if hasattr(obj, "__clause_element__"):
                return clone(obj, **kw)
            return obj

        def copy(elem):
            if isinstance(elem, (list, tuple)):
                return [clone_if_expr(value) for value in elem]
            if isinstance(elem, dict):
                return {
                    clone_if_expr(key): clone_if_expr(value)
                    for key, value in elem.items()
                }
            # TODO: use abc classes
            assert False

        return [
            [copy(sub_element) for sub_element in sequence]
            for sequence in element
        ]

    def visit_propagate_attrs(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # returned as is, not cloned
        return element
371
372
# module-level _CopyInternalsTraversal instance; used by
# _preconfigure_traversals() and HasCopyInternals._copy_internals()
_copy_internals = _CopyInternalsTraversal()
374
375
376def _flatten_clauseelement(element):
377 while hasattr(element, "__clause_element__") and not getattr(
378 element, "is_clause_element", False
379 ):
380 element = element.__clause_element__()
381
382 return element
383
384
class _GetChildrenTraversal(HasTraversalDispatch):
    """Generate a _children_traversal internal traversal dispatch for classes
    with a _traverse_internals collection.

    Each ``visit_*`` method receives an attribute value and returns an
    iterable of the child elements found within it.

    """

    def visit_has_cache_key(self, element, **kw):
        # the GetChildren traversal refers explicitly to ClauseElement
        # structures.  Within these, a plain HasCacheKey is not a
        # ClauseElement, so don't include these.
        return ()

    def visit_clauseelement(self, element, **kw):
        return (element,)

    def visit_clauseelement_list(self, element, **kw):
        return element

    def visit_clauseelement_tuple(self, element, **kw):
        return element

    def visit_clauseelement_tuples(self, element, **kw):
        # flatten the sequence of tuples into one stream of elements
        return itertools.chain.from_iterable(element)

    def visit_fromclause_canonical_column_collection(self, element, **kw):
        return ()

    def visit_string_clauseelement_dict(self, element, **kw):
        return element.values()

    def visit_fromclause_ordered_set(self, element, **kw):
        return element

    def visit_clauseelement_unordered_set(self, element, **kw):
        return element

    def visit_setup_join_tuple(self, element, **kw):
        for target, onclause, from_, flags in element:
            if from_ is not None:
                yield from_

            # string targets / onclauses are not yielded
            if not isinstance(target, str):
                yield _flatten_clauseelement(target)

            if not (onclause is None or isinstance(onclause, str)):
                yield _flatten_clauseelement(onclause)

    def visit_memoized_select_entities(self, element, **kw):
        return self.visit_clauseelement_tuple(element, **kw)

    def visit_dml_ordered_values(self, element, **kw):
        for key, value in element:
            if hasattr(key, "__clause_element__"):
                yield key
            yield value

    def visit_dml_values(self, element, **kw):
        # split the keys into those that provide __clause_element__ and
        # the remaining ones
        expr_keys = set()
        plain_keys = set()
        for key in element:
            if hasattr(key, "__clause_element__"):
                expr_keys.add(key)
            else:
                plain_keys.add(key)

        # values for the remaining keys come first, in sorted key order
        for key in sorted(plain_keys):
            yield element[key]
        for key in expr_keys:
            yield key
            yield element[key]

    def visit_dml_multi_values(self, element, **kw):
        return ()

    def visit_propagate_attrs(self, element, **kw):
        return ()
454
455
# module-level _GetChildrenTraversal instance; used by
# _preconfigure_traversals() to generate per-class dispatchers
_get_children = _GetChildrenTraversal()
457
458
@util.preload_module("sqlalchemy.sql.elements")
def _resolve_name_for_compare(element, name, anon_map, **kw):
    """Return ``name`` in a form suitable for comparison.

    An ``_anonymous_label`` is run through its ``apply_map()`` method with
    the given ``anon_map``; any other value is returned unchanged.

    """
    anonymous_label_cls = util.preloaded.sql_elements._anonymous_label
    if not isinstance(name, anonymous_label_cls):
        return name

    return name.apply_map(anon_map)
465
466
class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
    """Compare two traversible structures attribute by attribute.

    :meth:`.compare` walks pairs of elements taken from a work queue.  For
    each pair, an optional ``compare_<visit_name>()`` method may handle some
    or all attributes; the remaining entries of ``_traverse_internals`` are
    compared through the ``visit_<symbol>()`` methods, which either return
    ``COMPARE_FAILED`` or push further pairs onto the queue.

    """

    __slots__ = "stack", "cache", "anon_map"

    def __init__(self):
        # pairs of elements still waiting to be compared
        self.stack: Deque[
            Tuple[
                Optional[ExternallyTraversible],
                Optional[ExternallyTraversible],
            ]
        ] = deque()
        # (left, right) pairs already taken up by compare()
        self.cache = set()

    def _memoized_attr_anon_map(self):
        # one anon_map per side: index 0 is used with "left" objects,
        # index 1 with "right" objects
        return (anon_map(), anon_map())

    def compare(
        self,
        obj1: ExternallyTraversible,
        obj2: ExternallyTraversible,
        **kw: Any,
    ) -> bool:
        """Return True if ``obj1`` and ``obj2`` compare as equivalent.

        :param obj1: left-hand element.
        :param obj2: right-hand element.
        :param kw: options passed along to the ``compare_*`` / ``visit_*``
         methods; ``compare_annotations`` (default False) controls whether
         the ``_annotations`` attribute takes part in the comparison.

        """
        stack = self.stack
        cache = self.cache

        compare_annotations = kw.get("compare_annotations", False)

        stack.append((obj1, obj2))

        while stack:
            left, right = stack.popleft()

            if left is right:
                continue
            elif left is None or right is None:
                # we know they are different so no match
                return False
            elif (left, right) in cache:
                continue
            cache.add((left, right))

            visit_name = left.__visit_name__
            if visit_name != right.__visit_name__:
                return False

            meth = getattr(self, "compare_%s" % visit_name, None)

            if meth:
                attributes_compared = meth(left, right, **kw)
                if attributes_compared is COMPARE_FAILED:
                    return False
                elif attributes_compared is SKIP_TRAVERSE:
                    continue

                # attributes_compared is returned as a list of attribute
                # names that were "handled" by the comparison method above.
                # remaining attribute names in the _traverse_internals
                # will be compared.
            else:
                attributes_compared = ()

            for (
                (left_attrname, left_visit_sym),
                (right_attrname, right_visit_sym),
            ) in zip_longest(
                left._traverse_internals,
                right._traverse_internals,
                fillvalue=(None, None),
            ):
                if not compare_annotations and (
                    (left_attrname == "_annotations")
                    or (right_attrname == "_annotations")
                ):
                    continue

                if (
                    left_attrname != right_attrname
                    or left_visit_sym is not right_visit_sym
                ):
                    return False
                elif left_attrname in attributes_compared:
                    continue

                assert left_visit_sym is not None
                assert left_attrname is not None
                assert right_attrname is not None

                dispatch = self.dispatch(left_visit_sym)
                assert dispatch is not None, (
                    f"{self.__class__} has no dispatch for "
                    f"'{self._dispatch_lookup[left_visit_sym]}'"
                )
                left_child = operator.attrgetter(left_attrname)(left)
                right_child = operator.attrgetter(right_attrname)(right)
                # None children are settled here, so visit methods are
                # only dispatched with two non-None values
                if left_child is None:
                    if right_child is not None:
                        return False
                    else:
                        continue
                elif right_child is None:
                    return False

                comparison = dispatch(
                    left_attrname, left, left_child, right, right_child, **kw
                )
                if comparison is COMPARE_FAILED:
                    return False

        return True

    def compare_inner(self, obj1, obj2, **kw):
        """Compare ``obj1`` and ``obj2`` using a fresh comparator instance of
        the same class, i.e. with its own stack, cache and anon maps."""
        comparator = self.__class__()
        return comparator.compare(obj1, obj2, **kw)

    def visit_has_cache_key(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
            self.anon_map[1], []
        ):
            return COMPARE_FAILED

    def visit_propagate_attrs(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return self.compare_inner(
            left.get("plugin_subject", None), right.get("plugin_subject", None)
        )

    def visit_has_cache_key_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            if l is None:
                if r is not None:
                    return COMPARE_FAILED
                else:
                    continue
            elif r is None:
                return COMPARE_FAILED

            if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
                self.anon_map[1], []
            ):
                return COMPARE_FAILED

    def visit_executable_options(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            if l is None:
                if r is not None:
                    return COMPARE_FAILED
                else:
                    continue
            elif r is None:
                return COMPARE_FAILED

            # options flagged _is_has_cache_key are compared by cache key,
            # the others directly
            if (
                l._gen_cache_key(self.anon_map[0], [])
                if l._is_has_cache_key
                else l
            ) != (
                r._gen_cache_key(self.anon_map[1], [])
                if r._is_has_cache_key
                else r
            ):
                return COMPARE_FAILED

    def visit_clauseelement(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        self.stack.append((left, right))

    def visit_fromclause_canonical_column_collection(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lcol, rcol in zip_longest(left, right, fillvalue=None):
            self.stack.append((lcol, rcol))

    def visit_fromclause_derived_column_collection(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        pass

    def visit_string_clauseelement_dict(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lstr, rstr in zip_longest(
            sorted(left), sorted(right), fillvalue=None
        ):
            if lstr != rstr:
                return COMPARE_FAILED
            self.stack.append((left[lstr], right[rstr]))

    def visit_clauseelement_tuples(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for ltup, rtup in zip_longest(left, right, fillvalue=None):
            if ltup is None or rtup is None:
                return COMPARE_FAILED

            for l, r in zip_longest(ltup, rtup, fillvalue=None):
                self.stack.append((l, r))

    def visit_multi_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            if isinstance(l, str):
                if not isinstance(r, str) or l != r:
                    return COMPARE_FAILED
            elif isinstance(r, str):
                if not isinstance(l, str) or l != r:
                    return COMPARE_FAILED
            else:
                self.stack.append((l, r))

    def visit_clauseelement_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            self.stack.append((l, r))

    def visit_clauseelement_tuple(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            self.stack.append((l, r))

    def _compare_unordered_sequences(self, seq1, seq2, **kw):
        """Return True if each element of ``seq1`` can be matched with a
        distinct element of ``seq2`` and both have the same length."""
        if seq1 is None:
            return seq2 is None
        elif seq2 is None:
            # mirror the check above; set(None) below would raise TypeError
            return False

        completed: Set[object] = set()
        for clause in seq1:
            for other_clause in set(seq2).difference(completed):
                if self.compare_inner(clause, other_clause, **kw):
                    completed.add(other_clause)
                    break
        return len(completed) == len(seq1) == len(seq2)

    def visit_clauseelement_unordered_set(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return self._compare_unordered_sequences(left, right, **kw)

    def visit_fromclause_ordered_set(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            self.stack.append((l, r))

    def visit_string(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_string_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_string_multi_dict(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lk, rk in zip_longest(
            sorted(left.keys()), sorted(right.keys()), fillvalue=(None, None)
        ):
            if lk != rk:
                return COMPARE_FAILED

            lv, rv = left[lk], right[rk]

            # test the dictionary *values*; the dictionaries themselves
            # are never HasCacheKey, which made the cache key branch
            # below unreachable
            lhc = isinstance(lv, HasCacheKey)
            rhc = isinstance(rv, HasCacheKey)
            if lhc and rhc:
                if lv._gen_cache_key(
                    self.anon_map[0], []
                ) != rv._gen_cache_key(self.anon_map[1], []):
                    return COMPARE_FAILED
            elif lhc != rhc:
                return COMPARE_FAILED
            elif lv != rv:
                return COMPARE_FAILED

    def visit_multi(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        lhc = isinstance(left, HasCacheKey)
        rhc = isinstance(right, HasCacheKey)
        if lhc and rhc:
            if left._gen_cache_key(
                self.anon_map[0], []
            ) != right._gen_cache_key(self.anon_map[1], []):
                return COMPARE_FAILED
        elif lhc != rhc:
            return COMPARE_FAILED
        else:
            return left == right

    def visit_anon_name(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return _resolve_name_for_compare(
            left_parent, left, self.anon_map[0], **kw
        ) == _resolve_name_for_compare(
            right_parent, right, self.anon_map[1], **kw
        )

    def visit_boolean(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_operator(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_type(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left._compare_type_affinity(right)

    def visit_plain_dict(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_dialect_options(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_annotations_key(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        if left and right:
            return (
                left_parent._annotations_cache_key
                == right_parent._annotations_cache_key
            )
        else:
            return left == right

    def visit_compile_state_funcs(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # functions are compared by their code objects
        return tuple((fn.__code__, c_key) for fn, c_key in left) == tuple(
            (fn.__code__, c_key) for fn, c_key in right
        )

    def visit_plain_obj(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_named_ddl_element(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        if left is None or right is None:
            # equal only when both are None; avoids AttributeError on
            # ``.name`` below
            return left is right

        return left.name == right.name

    def visit_prefix_sequence(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for (l_clause, l_str), (r_clause, r_str) in zip_longest(
            left, right, fillvalue=(None, None)
        ):
            if l_str != r_str:
                return COMPARE_FAILED
            else:
                self.stack.append((l_clause, r_clause))

    def visit_setup_join_tuple(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # TODO: look at attrname for "legacy_join" and use different structure
        for (
            (l_target, l_onclause, l_from, l_flags),
            (r_target, r_onclause, r_from, r_flags),
        ) in zip_longest(left, right, fillvalue=(None, None, None, None)):
            if l_flags != r_flags:
                return COMPARE_FAILED
            self.stack.append((l_target, r_target))
            self.stack.append((l_onclause, r_onclause))
            self.stack.append((l_from, r_from))

    def visit_memoized_select_entities(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return self.visit_clauseelement_tuple(
            attrname, left_parent, left, right_parent, right, **kw
        )

    def visit_table_hint_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # keys are (table, dialect) pairs; order them deterministically
        # before pairing the two sides up
        left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
        right_keys = sorted(
            right, key=lambda elem: (elem[0].fullname, elem[1])
        )
        for (ltable, ldialect), (rtable, rdialect) in zip_longest(
            left_keys, right_keys, fillvalue=(None, None)
        ):
            if ldialect != rdialect:
                return COMPARE_FAILED
            elif left[(ltable, ldialect)] != right[(rtable, rdialect)]:
                return COMPARE_FAILED
            else:
                self.stack.append((ltable, rtable))

    def visit_statement_hint_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_unknown_structure(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        raise NotImplementedError()

    def visit_dml_ordered_values(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # sequence of tuple pairs

        # NOTE(review): only the keys of each pair are compared here; the
        # values (lv / rv) are unpacked but not used -- confirm whether
        # that is intended
        for (lk, lv), (rk, rv) in zip_longest(
            left, right, fillvalue=(None, None)
        ):
            if not self._compare_dml_values_or_ce(lk, rk, **kw):
                return COMPARE_FAILED

    def _compare_dml_values_or_ce(self, lv, rv, **kw):
        """Compare two DML keys or values, returning True or False.

        Objects providing ``__clause_element__`` are compared with
        :meth:`.compare_inner`; other objects are first compared with
        ``!=``.  A pair passing those checks is then passed to
        :meth:`.compare_inner` as well.

        """
        lvce = hasattr(lv, "__clause_element__")
        rvce = hasattr(rv, "__clause_element__")
        if lvce != rvce:
            return False
        elif lvce and not self.compare_inner(lv, rv, **kw):
            return False
        elif not lvce and lv != rv:
            return False
        elif not self.compare_inner(lv, rv, **kw):
            return False

        return True

    def visit_dml_values(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        if left is None or right is None or len(left) != len(right):
            return COMPARE_FAILED

        if isinstance(left, collections_abc.Sequence):
            for lv, rv in zip(left, right):
                if not self._compare_dml_values_or_ce(lv, rv, **kw):
                    return COMPARE_FAILED
        elif isinstance(right, collections_abc.Sequence):
            return COMPARE_FAILED
        else:
            # dictionaries guaranteed to support insert ordering in
            # py37 so that we can compare the keys in order.  without
            # this, we can't compare SQL expression keys because we don't
            # know which key is which
            for (lk, lv), (rk, rv) in zip(left.items(), right.items()):
                if not self._compare_dml_values_or_ce(lk, rk, **kw):
                    return COMPARE_FAILED
                if not self._compare_dml_values_or_ce(lv, rv, **kw):
                    return COMPARE_FAILED

    def visit_dml_multi_values(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lseq, rseq in zip_longest(left, right, fillvalue=None):
            if lseq is None or rseq is None:
                return COMPARE_FAILED

            for ld, rd in zip_longest(lseq, rseq, fillvalue=None):
                if (
                    self.visit_dml_values(
                        attrname, left_parent, ld, right_parent, rd, **kw
                    )
                    is COMPARE_FAILED
                ):
                    return COMPARE_FAILED

    def compare_expression_clauselist(self, left, right, **kw):
        """Compare the operators of two clause lists and, for associative
        operators, their clauses without regard to ordering.

        Returns the list of attribute names already handled, or
        ``COMPARE_FAILED``.

        """
        if left.operator is right.operator:
            if operators.is_associative(left.operator):
                if self._compare_unordered_sequences(
                    left.clauses, right.clauses, **kw
                ):
                    return ["operator", "clauses"]
                else:
                    return COMPARE_FAILED
            else:
                return ["operator"]
        else:
            return COMPARE_FAILED

    def compare_clauselist(self, left, right, **kw):
        return self.compare_expression_clauselist(left, right, **kw)

    def compare_binary(self, left, right, **kw):
        """Compare the operators of two binary expressions and, for
        commutative operators, their operands in either order.

        Returns the list of attribute names already handled, or
        ``COMPARE_FAILED``.

        """
        if left.operator == right.operator:
            if operators.is_commutative(left.operator):
                if (
                    self.compare_inner(left.left, right.left, **kw)
                    and self.compare_inner(left.right, right.right, **kw)
                ) or (
                    self.compare_inner(left.left, right.right, **kw)
                    and self.compare_inner(left.right, right.left, **kw)
                ):
                    return ["operator", "negate", "left", "right"]
                else:
                    return COMPARE_FAILED
            else:
                return ["operator", "negate"]
        else:
            return COMPARE_FAILED

    def compare_bindparam(self, left, right, **kw):
        """Return the attribute names compare() should skip for a pair of
        bound parameters, based on the ``compare_keys`` and
        ``compare_values`` keywords (both default to True)."""
        compare_keys = kw.pop("compare_keys", True)
        compare_values = kw.pop("compare_values", True)

        if compare_values:
            omit = []
        else:
            # this means, "skip these, we already compared"
            omit = ["callable", "value"]

        if not compare_keys:
            omit.append("key")

        return omit
1005
1006
class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
    """Comparator that matches columns, labels and tables by identity /
    lineage instead of traversing their structure."""

    def compare_column_element(
        self, left, right, use_proxies=True, equivalents=(), **kw
    ):
        """Compare ColumnElements using proxies and equivalent collections.

        This is a comparison strategy specific to the ORM.

        :param use_proxies: when True, ``left`` matches a candidate if
         ``left.shares_lineage(candidate)`` is true.
        :param equivalents: optional mapping of a column to a set of
         columns; when ``right`` is a key, its set is added to the
         candidates.
        :return: ``SKIP_TRAVERSE`` when a candidate matches, otherwise
         ``COMPARE_FAILED``.
        """

        to_compare = (right,)
        if equivalents and right in equivalents:
            to_compare = equivalents[right].union(to_compare)

        for oth in to_compare:
            if use_proxies and left.shares_lineage(oth):
                return SKIP_TRAVERSE
            elif hash(left) == hash(oth):
                # compare against the current candidate; testing ``right``
                # here would be loop-invariant and would ignore the
                # equivalent columns
                return SKIP_TRAVERSE

        return COMPARE_FAILED

    def compare_column(self, left, right, **kw):
        return self.compare_column_element(left, right, **kw)

    def compare_label(self, left, right, **kw):
        return self.compare_column_element(left, right, **kw)

    def compare_table(self, left, right, **kw):
        # tables compare on identity, since it's not really feasible to
        # compare them column by column with the above rules
        return SKIP_TRAVERSE if left is right else COMPARE_FAILED