1# sql/traversals.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11from collections import deque
12import collections.abc as collections_abc
13import itertools
14from itertools import zip_longest
15import operator
16import typing
17from typing import Any
18from typing import Callable
19from typing import Deque
20from typing import Dict
21from typing import Iterable
22from typing import Optional
23from typing import Set
24from typing import Tuple
25from typing import Type
26
27from . import operators
28from .cache_key import HasCacheKey
29from .visitors import _TraverseInternalsType
30from .visitors import anon_map
31from .visitors import ExternallyTraversible
32from .visitors import HasTraversalDispatch
33from .visitors import HasTraverseInternals
34from .. import util
35from ..util import langhelpers
36from ..util.typing import Self
37
# Sentinel a compare_<visit_name>() method may return to say "these two
# elements match; do not descend into their _traverse_internals".
SKIP_TRAVERSE = util.symbol("skip_traverse")

# Result values used by the comparison machinery; callers test the failure
# case by identity (``is COMPARE_FAILED``).
COMPARE_FAILED, COMPARE_SUCCEEDED = False, True
41
42
def compare(obj1: Any, obj2: Any, **kw: Any) -> bool:
    """Return True if ``obj1`` and ``obj2`` are structurally equivalent.

    A truthy ``use_proxies`` keyword selects
    :class:`ColIdentityComparatorStrategy`; otherwise the plain
    :class:`TraversalComparatorStrategy` is used.  All keyword arguments are
    forwarded unchanged to the chosen strategy's ``compare()``.
    """
    strategy_cls: Type[TraversalComparatorStrategy] = (
        ColIdentityComparatorStrategy
        if kw.get("use_proxies", False)
        else TraversalComparatorStrategy
    )
    return strategy_cls().compare(obj1, obj2, **kw)
51
52
def _preconfigure_traversals(target_hierarchy: Type[Any]) -> None:
    """Eagerly build traversal machinery for a class hierarchy.

    For every class yielded by ``util.walk_subclasses(target_hierarchy)``
    that defines both ``_generate_cache_attrs`` and ``_traverse_internals``,
    generate its cache attributes and then compile its copy-internals and
    get-children dispatch functions up front.
    """
    for klass in util.walk_subclasses(target_hierarchy):
        if not hasattr(klass, "_generate_cache_attrs"):
            continue
        if not hasattr(klass, "_traverse_internals"):
            continue

        klass._generate_cache_attrs()
        for traversal, generated_name in (
            (_copy_internals, "_generated_copy_internals_traversal"),
            (_get_children, "_generated_get_children_traversal"),
        ):
            traversal.generate_dispatch(
                klass, klass._traverse_internals, generated_name
            )
69
70
class HasShallowCopy(HasTraverseInternals):
    """attribute-wide operations that are useful for classes that use
    __slots__ and therefore can't operate on their attributes in a dictionary.

    Each operation compiles, once per class, a small function that touches
    every attribute named in ``_traverse_internals`` and caches it on the
    class under a ``_generated_*`` name.
    """

    __slots__ = ()

    if typing.TYPE_CHECKING:

        def _generated_shallow_copy_traversal(self, other: Self) -> None: ...

        def _generated_shallow_from_dict_traversal(
            self, d: Dict[str, Any]
        ) -> None: ...

        def _generated_shallow_to_dict_traversal(self) -> Dict[str, Any]: ...

    @classmethod
    def _generate_shallow_copy(
        cls,
        internal_dispatch: _TraverseInternalsType,
        method_name: str,
    ) -> Callable[[Self, Self], None]:
        """Compile ``method_name(self, other)`` which assigns every traversed
        attribute of ``self`` onto ``other``."""
        # an empty internal_dispatch would otherwise yield a ``def`` with no
        # body, which is a SyntaxError when executed
        code = (
            "\n".join(
                f"    other.{attrname} = self.{attrname}"
                for attrname, _ in internal_dispatch
            )
            or "    pass"
        )
        meth_text = f"def {method_name}(self, other):\n{code}\n"
        return langhelpers._exec_code_in_env(meth_text, {}, method_name)

    @classmethod
    def _generate_shallow_to_dict(
        cls,
        internal_dispatch: _TraverseInternalsType,
        method_name: str,
    ) -> Callable[[Self], Dict[str, Any]]:
        """Compile ``method_name(self)`` returning a dict of every traversed
        attribute name to its current value."""
        code = ",\n".join(
            f"    '{attrname}': self.{attrname}"
            for attrname, _ in internal_dispatch
        )
        meth_text = f"def {method_name}(self):\n    return {{{code}}}\n"
        return langhelpers._exec_code_in_env(meth_text, {}, method_name)

    @classmethod
    def _generate_shallow_from_dict(
        cls,
        internal_dispatch: _TraverseInternalsType,
        method_name: str,
    ) -> Callable[[Self, Dict[str, Any]], None]:
        """Compile ``method_name(self, d)`` which assigns every traversed
        attribute of ``self`` from ``d[attrname]``."""
        # same empty-body guard as _generate_shallow_copy
        code = (
            "\n".join(
                f"    self.{attrname} = d['{attrname}']"
                for attrname, _ in internal_dispatch
            )
            or "    pass"
        )
        meth_text = f"def {method_name}(self, d):\n{code}\n"
        return langhelpers._exec_code_in_env(meth_text, {}, method_name)

    def _shallow_from_dict(self, d: Dict[str, Any]) -> None:
        """Populate this object's traversed attributes from ``d``."""
        cls = self.__class__

        shallow_from_dict: Callable[[HasShallowCopy, Dict[str, Any]], None]
        try:
            # look only in this class's own __dict__ so a subclass never
            # reuses a function generated for its parent's internals
            shallow_from_dict = cls.__dict__[
                "_generated_shallow_from_dict_traversal"
            ]
        except KeyError:
            shallow_from_dict = self._generate_shallow_from_dict(
                cls._traverse_internals,
                "_generated_shallow_from_dict_traversal",
            )

            cls._generated_shallow_from_dict_traversal = shallow_from_dict  # type: ignore # noqa: E501

        shallow_from_dict(self, d)

    def _shallow_to_dict(self) -> Dict[str, Any]:
        """Return a dict of this object's traversed attributes."""
        cls = self.__class__

        shallow_to_dict: Callable[[HasShallowCopy], Dict[str, Any]]

        try:
            shallow_to_dict = cls.__dict__[
                "_generated_shallow_to_dict_traversal"
            ]
        except KeyError:
            shallow_to_dict = self._generate_shallow_to_dict(
                cls._traverse_internals, "_generated_shallow_to_dict_traversal"
            )

            cls._generated_shallow_to_dict_traversal = shallow_to_dict  # type: ignore # noqa: E501
        return shallow_to_dict(self)

    def _shallow_copy_to(self, other: Self) -> None:
        """Copy this object's traversed attributes onto ``other``."""
        cls = self.__class__

        shallow_copy: Callable[[Self, Self], None]
        try:
            shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"]
        except KeyError:
            shallow_copy = self._generate_shallow_copy(
                cls._traverse_internals, "_generated_shallow_copy_traversal"
            )

            cls._generated_shallow_copy_traversal = shallow_copy  # type: ignore # noqa: E501
        shallow_copy(self, other)

    def _clone(self, **kw: Any) -> Self:
        """Create a shallow copy"""
        c = self.__class__.__new__(self.__class__)
        self._shallow_copy_to(c)
        return c
183
184
class GenerativeOnTraversal(HasShallowCopy):
    """Supplies Generative behavior but making use of traversals to shallow
    copy.

    .. seealso::

        :class:`sqlalchemy.sql.base.Generative`


    """

    __slots__ = ()

    def _generate(self) -> Self:
        # allocate without running __init__, then copy every traversed
        # attribute across from self
        new_obj = self.__class__.__new__(self.__class__)
        self._shallow_copy_to(new_obj)
        return new_obj
203
204
205def _clone(element, **kw):
206 return element._clone()
207
208
class HasCopyInternals(HasTraverseInternals):
    """Mixin providing ``_copy_internals()`` driven by
    ``_traverse_internals``."""

    __slots__ = ()

    def _clone(self, **kw):
        raise NotImplementedError()

    def _copy_internals(
        self, *, omit_attrs: Iterable[str] = (), **kw: Any
    ) -> None:
        """Reassign internal elements to be clones of themselves.

        Called during a copy-and-traverse operation on newly
        shallow-copied elements to create a deep copy.

        The given clone function should be used, which may be applying
        additional transformations to the element (i.e. replacement
        traversal, cloned traversal, annotations).

        :param omit_attrs: attribute names that are left untouched.
        :param kw: forwarded to each per-attribute copy method.
        """

        try:
            traverse_internals = self._traverse_internals
        except AttributeError:
            # user-defined classes may not have a _traverse_internals
            return

        # materialize once: omit_attrs may be any iterable (e.g. a
        # generator), and a one-shot iterator would be consumed by the
        # first membership test below
        omit = frozenset(omit_attrs)

        for attrname, obj, meth in _copy_internals.run_generated_dispatch(
            self, traverse_internals, "_generated_copy_internals_traversal"
        ):
            if attrname in omit or obj is None:
                continue

            result = meth(attrname, self, obj, **kw)
            # a None result means "leave the attribute as-is"
            if result is not None:
                setattr(self, attrname, result)
245
246
class _CopyInternalsTraversal(HasTraversalDispatch):
    """Generate a _copy_internals internal traversal dispatch for classes
    with a _traverse_internals collection.

    Each ``visit_<symbol>`` method receives the current attribute value and
    returns its replacement, built by applying ``clone`` to the contained
    elements.
    """

    def visit_clauseelement(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return clone(element, **kw)

    def visit_clauseelement_list(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return [clone(clause, **kw) for clause in element]

    def visit_clauseelement_tuple(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return tuple([clone(clause, **kw) for clause in element])

    def visit_executable_options(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return tuple([clone(clause, **kw) for clause in element])

    def visit_clauseelement_unordered_set(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return {clone(clause, **kw) for clause in element}

    def visit_clauseelement_tuples(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return [
            tuple(clone(tup_elem, **kw) for tup_elem in elem)
            for elem in element
        ]

    def visit_string_clauseelement_dict(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return {key: clone(value, **kw) for key, value in element.items()}

    def visit_setup_join_tuple(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # each entry is (target, onclause, from_, flags); any of the first
        # three may be None and flags is carried over unchanged
        return tuple(
            (
                clone(target, **kw) if target is not None else None,
                clone(onclause, **kw) if onclause is not None else None,
                clone(from_, **kw) if from_ is not None else None,
                flags,
            )
            for (target, onclause, from_, flags) in element
        )

    def visit_memoized_select_entities(self, attrname, parent, element, **kw):
        return self.visit_clauseelement_tuple(attrname, parent, element, **kw)

    def visit_dml_ordered_values(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # sequence of 2-tuples; keys without __clause_element__ are kept
        # as-is, values are always cloned
        return [
            (
                (
                    clone(key, **kw)
                    if hasattr(key, "__clause_element__")
                    else key
                ),
                clone(value, **kw),
            )
            for key, value in element
        ]

    def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw):
        return {
            (
                clone(key, **kw) if hasattr(key, "__clause_element__") else key
            ): clone(value, **kw)
            for key, value in element.items()
        }

    def visit_dml_multi_values(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # sequence of sequences, each sequence contains a list/dict/tuple

        def copy(elem):
            if isinstance(elem, (list, tuple)):
                return [
                    (
                        clone(value, **kw)
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                    for value in elem
                ]
            elif isinstance(elem, dict):
                return {
                    (
                        clone(key, **kw)
                        if hasattr(key, "__clause_element__")
                        else key
                    ): (
                        clone(value, **kw)
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                    for key, value in elem.items()
                }
            else:
                # TODO: use abc classes
                # raise explicitly rather than ``assert False``: asserts are
                # stripped under ``python -O``, which would silently insert
                # None into the copied structure
                raise TypeError(
                    f"unexpected multi-values parameter element: {elem!r}"
                )

        return [
            [copy(sub_element) for sub_element in sequence]
            for sequence in element
        ]

    def visit_propagate_attrs(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # propagate_attrs are shared, not cloned
        return element


_copy_internals = _CopyInternalsTraversal()
373
374
375def _flatten_clauseelement(element):
376 while hasattr(element, "__clause_element__") and not getattr(
377 element, "is_clause_element", False
378 ):
379 element = element.__clause_element__()
380
381 return element
382
383
class _GetChildrenTraversal(HasTraversalDispatch):
    """Generate a _children_traversal internal traversal dispatch for classes
    with a _traverse_internals collection.

    Each ``visit_<symbol>`` method returns an iterable of the child elements
    held by one attribute.
    """

    def visit_has_cache_key(self, element, **kw):
        # the GetChildren traversal refers explicitly to ClauseElement
        # structures.  Within these, a plain HasCacheKey is not a
        # ClauseElement, so don't include these.
        return ()

    def visit_clauseelement(self, element, **kw):
        return (element,)

    def visit_clauseelement_list(self, element, **kw):
        return element

    def visit_clauseelement_tuple(self, element, **kw):
        return element

    def visit_clauseelement_tuples(self, element, **kw):
        return itertools.chain.from_iterable(element)

    def visit_fromclause_canonical_column_collection(self, element, **kw):
        return ()

    def visit_string_clauseelement_dict(self, element, **kw):
        return element.values()

    def visit_fromclause_ordered_set(self, element, **kw):
        return element

    def visit_clauseelement_unordered_set(self, element, **kw):
        return element

    def visit_setup_join_tuple(self, element, **kw):
        for target, onclause, from_, flags in element:
            if from_ is not None:
                yield from_

            # the copy traversal treats target as possibly None; never yield
            # None as a child, it has no get_children() of its own
            if target is not None and not isinstance(target, str):
                yield _flatten_clauseelement(target)

            if onclause is not None and not isinstance(onclause, str):
                yield _flatten_clauseelement(onclause)

    def visit_memoized_select_entities(self, element, **kw):
        return self.visit_clauseelement_tuple(element, **kw)

    def visit_dml_ordered_values(self, element, **kw):
        for k, v in element:
            if hasattr(k, "__clause_element__"):
                yield k
            yield v

    def visit_dml_values(self, element, **kw):
        # split keys into expression keys and the rest; the rest are
        # yielded in sorted order, expression keys in set order
        expr_values = {k for k in element if hasattr(k, "__clause_element__")}
        str_values = expr_values.symmetric_difference(element)

        for k in sorted(str_values):
            yield element[k]
        for k in expr_values:
            yield k
            yield element[k]

    def visit_dml_multi_values(self, element, **kw):
        return ()

    def visit_propagate_attrs(self, element, **kw):
        return ()


_get_children = _GetChildrenTraversal()
456
457
@util.preload_module("sqlalchemy.sql.elements")
def _resolve_name_for_compare(element, name, anon_map, **kw):
    """Return ``name`` with anonymous-label tokens resolved through
    ``anon_map``; names that are not ``_anonymous_label`` instances are
    returned unchanged."""
    anonymous_label_cls = util.preloaded.sql_elements._anonymous_label
    if not isinstance(name, anonymous_label_cls):
        return name
    return name.apply_map(anon_map)
464
465
class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
    """Compare two element trees for structural equivalence.

    Pairs of ``(left, right)`` elements are consumed from ``self.stack``.
    For each pair an optional ``compare_<visit_name>()`` method may decide
    the outcome outright (``COMPARE_FAILED`` / ``SKIP_TRAVERSE``) or return
    the names of attributes it already handled.  Every other attribute
    listed in ``_traverse_internals`` is compared through the
    ``visit_<symbol>()`` dispatch methods below.  A ``visit_*`` method
    reports a mismatch by returning ``COMPARE_FAILED`` and may push child
    pairs onto ``self.stack`` to be compared later.
    """

    __slots__ = "stack", "cache", "anon_map"

    def __init__(self):
        # pending (left, right) pairs, consumed FIFO by compare()
        self.stack: Deque[
            Tuple[
                Optional[ExternallyTraversible],
                Optional[ExternallyTraversible],
            ]
        ] = deque()
        # (left, right) pairs already examined; repeats are skipped
        self.cache = set()

    def _memoized_attr_anon_map(self):
        # one anon_map per side, used for cache-key generation and for
        # resolving anonymous labels
        return (anon_map(), anon_map())

    def compare(
        self,
        obj1: ExternallyTraversible,
        obj2: ExternallyTraversible,
        **kw: Any,
    ) -> bool:
        """Return True if ``obj1`` and ``obj2`` are structurally equivalent.

        ``compare_annotations`` (default False) controls whether
        ``_annotations`` attributes participate.  All keyword arguments are
        forwarded to the ``compare_*`` and ``visit_*`` methods.
        """
        stack = self.stack
        cache = self.cache

        compare_annotations = kw.get("compare_annotations", False)

        stack.append((obj1, obj2))

        while stack:
            left, right = stack.popleft()

            if left is right:
                continue
            elif left is None or right is None:
                # we know they are different so no match
                return False
            elif (left, right) in cache:
                continue
            cache.add((left, right))

            visit_name = left.__visit_name__
            if visit_name != right.__visit_name__:
                return False

            meth = getattr(self, "compare_%s" % visit_name, None)

            if meth:
                attributes_compared = meth(left, right, **kw)
                if attributes_compared is COMPARE_FAILED:
                    return False
                elif attributes_compared is SKIP_TRAVERSE:
                    continue

                # attributes_compared is returned as a list of attribute
                # names that were "handled" by the comparison method above.
                # remaining attribute names in the _traverse_internals
                # will be compared.
            else:
                attributes_compared = ()

            for (
                (left_attrname, left_visit_sym),
                (right_attrname, right_visit_sym),
            ) in zip_longest(
                left._traverse_internals,
                right._traverse_internals,
                fillvalue=(None, None),
            ):
                if not compare_annotations and (
                    (left_attrname == "_annotations")
                    or (right_attrname == "_annotations")
                ):
                    continue

                if (
                    left_attrname != right_attrname
                    or left_visit_sym is not right_visit_sym
                ):
                    return False
                elif left_attrname in attributes_compared:
                    continue

                assert left_visit_sym is not None
                assert left_attrname is not None
                assert right_attrname is not None

                dispatch = self.dispatch(left_visit_sym)
                assert dispatch is not None, (
                    f"{self.__class__} has no dispatch for "
                    f"'{self._dispatch_lookup[left_visit_sym]}'"
                )
                left_child = operator.attrgetter(left_attrname)(left)
                right_child = operator.attrgetter(right_attrname)(right)
                if left_child is None:
                    if right_child is not None:
                        return False
                    else:
                        continue
                elif right_child is None:
                    return False

                comparison = dispatch(
                    left_attrname, left, left_child, right, right_child, **kw
                )
                if comparison is COMPARE_FAILED:
                    return False

        return True

    def compare_inner(self, obj1, obj2, **kw):
        """Compare ``obj1`` and ``obj2`` with a fresh comparator that has its
        own stack, cache and anon maps."""
        comparator = self.__class__()
        return comparator.compare(obj1, obj2, **kw)

    def visit_has_cache_key(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
            self.anon_map[1], []
        ):
            return COMPARE_FAILED

    def visit_propagate_attrs(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return self.compare_inner(
            left.get("plugin_subject", None), right.get("plugin_subject", None)
        )

    def visit_has_cache_key_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            if l is None:
                if r is not None:
                    return COMPARE_FAILED
                else:
                    continue
            elif r is None:
                return COMPARE_FAILED

            if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
                self.anon_map[1], []
            ):
                return COMPARE_FAILED

    def visit_executable_options(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            if l is None:
                if r is not None:
                    return COMPARE_FAILED
                else:
                    continue
            elif r is None:
                return COMPARE_FAILED

            # options that are not HasCacheKey are compared directly
            if (
                l._gen_cache_key(self.anon_map[0], [])
                if l._is_has_cache_key
                else l
            ) != (
                r._gen_cache_key(self.anon_map[1], [])
                if r._is_has_cache_key
                else r
            ):
                return COMPARE_FAILED

    def visit_clauseelement(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        self.stack.append((left, right))

    def visit_fromclause_canonical_column_collection(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lcol, rcol in zip_longest(left, right, fillvalue=None):
            self.stack.append((lcol, rcol))

    def visit_fromclause_derived_column_collection(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        pass

    def visit_string_clauseelement_dict(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lstr, rstr in zip_longest(
            sorted(left), sorted(right), fillvalue=None
        ):
            if lstr != rstr:
                return COMPARE_FAILED
            self.stack.append((left[lstr], right[rstr]))

    def visit_clauseelement_tuples(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for ltup, rtup in zip_longest(left, right, fillvalue=None):
            if ltup is None or rtup is None:
                return COMPARE_FAILED

            for l, r in zip_longest(ltup, rtup, fillvalue=None):
                self.stack.append((l, r))

    def visit_clauseelement_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            self.stack.append((l, r))

    def visit_clauseelement_tuple(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            self.stack.append((l, r))

    def _compare_unordered_sequences(self, seq1, seq2, **kw):
        """Return True if every element of ``seq1`` can be matched, via
        :meth:`compare_inner`, to a distinct element of ``seq2`` and both
        sequences have the same length."""
        if seq1 is None:
            return seq2 is None

        completed: Set[object] = set()
        for clause in seq1:
            for other_clause in set(seq2).difference(completed):
                if self.compare_inner(clause, other_clause, **kw):
                    completed.add(other_clause)
                    break
        return len(completed) == len(seq1) == len(seq2)

    def visit_clauseelement_unordered_set(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return self._compare_unordered_sequences(left, right, **kw)

    def visit_fromclause_ordered_set(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            self.stack.append((l, r))

    def visit_string(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_string_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_string_multi_dict(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lk, rk in zip_longest(
            sorted(left.keys()), sorted(right.keys()), fillvalue=(None, None)
        ):
            if lk != rk:
                return COMPARE_FAILED

            lv, rv = left[lk], right[rk]

            # test the *values* for HasCacheKey; the containing dicts never
            # are, so checking left/right here would always be False and
            # the cache-key branch could never run
            lhc = isinstance(lv, HasCacheKey)
            rhc = isinstance(rv, HasCacheKey)
            if lhc and rhc:
                if lv._gen_cache_key(
                    self.anon_map[0], []
                ) != rv._gen_cache_key(self.anon_map[1], []):
                    return COMPARE_FAILED
            elif lhc != rhc:
                return COMPARE_FAILED
            elif lv != rv:
                return COMPARE_FAILED

    def visit_multi(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        lhc = isinstance(left, HasCacheKey)
        rhc = isinstance(right, HasCacheKey)
        if lhc and rhc:
            if left._gen_cache_key(
                self.anon_map[0], []
            ) != right._gen_cache_key(self.anon_map[1], []):
                return COMPARE_FAILED
        elif lhc != rhc:
            return COMPARE_FAILED
        else:
            return left == right

    def visit_anon_name(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return _resolve_name_for_compare(
            left_parent, left, self.anon_map[0], **kw
        ) == _resolve_name_for_compare(
            right_parent, right, self.anon_map[1], **kw
        )

    def visit_boolean(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_operator(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_type(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left._compare_type_affinity(right)

    def visit_plain_dict(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_dialect_options(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_annotations_key(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        if left and right:
            return (
                left_parent._annotations_cache_key
                == right_parent._annotations_cache_key
            )
        else:
            return left == right

    def visit_with_context_options(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # callables are compared by their code objects, not identity
        return tuple((fn.__code__, c_key) for fn, c_key in left) == tuple(
            (fn.__code__, c_key) for fn, c_key in right
        )

    def visit_plain_obj(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_named_ddl_element(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # guard both sides so None never reaches the ``.name`` access:
        # two Nones match, exactly one None does not
        if left is None or right is None:
            return left is right

        return left.name == right.name

    def visit_prefix_sequence(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for (l_clause, l_str), (r_clause, r_str) in zip_longest(
            left, right, fillvalue=(None, None)
        ):
            if l_str != r_str:
                return COMPARE_FAILED
            else:
                self.stack.append((l_clause, r_clause))

    def visit_setup_join_tuple(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # TODO: look at attrname for "legacy_join" and use different structure
        for (
            (l_target, l_onclause, l_from, l_flags),
            (r_target, r_onclause, r_from, r_flags),
        ) in zip_longest(left, right, fillvalue=(None, None, None, None)):
            if l_flags != r_flags:
                return COMPARE_FAILED
            self.stack.append((l_target, r_target))
            self.stack.append((l_onclause, r_onclause))
            self.stack.append((l_from, r_from))

    def visit_memoized_select_entities(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return self.visit_clauseelement_tuple(
            attrname, left_parent, left, right_parent, right, **kw
        )

    def visit_table_hint_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
        right_keys = sorted(
            right, key=lambda elem: (elem[0].fullname, elem[1])
        )
        for (ltable, ldialect), (rtable, rdialect) in zip_longest(
            left_keys, right_keys, fillvalue=(None, None)
        ):
            if ldialect != rdialect:
                return COMPARE_FAILED
            elif left[(ltable, ldialect)] != right[(rtable, rdialect)]:
                return COMPARE_FAILED
            else:
                self.stack.append((ltable, rtable))

    def visit_statement_hint_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_unknown_structure(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        raise NotImplementedError()

    def visit_dml_ordered_values(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # sequence of (key, value) pairs; both the key AND the value must
        # match (previously only keys were compared, so statements that
        # differed only in their SET values compared equal)
        for (lk, lv), (rk, rv) in zip_longest(
            left, right, fillvalue=(None, None)
        ):
            if not self._compare_dml_values_or_ce(lk, rk, **kw):
                return COMPARE_FAILED
            if not self._compare_dml_values_or_ce(lv, rv, **kw):
                return COMPARE_FAILED

    def _compare_dml_values_or_ce(self, lv, rv, **kw):
        """Compare one DML key or value pair.

        Objects exposing ``__clause_element__`` are compared structurally
        via :meth:`compare_inner`; any other values (e.g. string keys) are
        compared with ``!=``.  A clause element never matches a non-clause
        value.
        """
        lvce = hasattr(lv, "__clause_element__")
        rvce = hasattr(rv, "__clause_element__")
        if lvce != rvce:
            return False
        elif lvce:
            return self.compare_inner(lv, rv, **kw)

        # plain values must not go through compare_inner(): they have no
        # __visit_name__, so two equal-but-distinct objects would raise
        # AttributeError there
        return not (lv != rv)

    def visit_dml_values(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        if left is None or right is None or len(left) != len(right):
            return COMPARE_FAILED

        if isinstance(left, collections_abc.Sequence):
            for lv, rv in zip(left, right):
                if not self._compare_dml_values_or_ce(lv, rv, **kw):
                    return COMPARE_FAILED
        elif isinstance(right, collections_abc.Sequence):
            return COMPARE_FAILED
        else:
            # dictionaries guaranteed to support insert ordering in
            # py37 so that we can compare the keys in order.  without
            # this, we can't compare SQL expression keys because we don't
            # know which key is which
            for (lk, lv), (rk, rv) in zip(left.items(), right.items()):
                if not self._compare_dml_values_or_ce(lk, rk, **kw):
                    return COMPARE_FAILED
                if not self._compare_dml_values_or_ce(lv, rv, **kw):
                    return COMPARE_FAILED

    def visit_dml_multi_values(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lseq, rseq in zip_longest(left, right, fillvalue=None):
            if lseq is None or rseq is None:
                return COMPARE_FAILED

            for ld, rd in zip_longest(lseq, rseq, fillvalue=None):
                if (
                    self.visit_dml_values(
                        attrname, left_parent, ld, right_parent, rd, **kw
                    )
                    is COMPARE_FAILED
                ):
                    return COMPARE_FAILED

    def compare_expression_clauselist(self, left, right, **kw):
        # associative operators allow the clauses in any order
        if left.operator is right.operator:
            if operators.is_associative(left.operator):
                if self._compare_unordered_sequences(
                    left.clauses, right.clauses, **kw
                ):
                    return ["operator", "clauses"]
                else:
                    return COMPARE_FAILED
            else:
                return ["operator"]
        else:
            return COMPARE_FAILED

    def compare_clauselist(self, left, right, **kw):
        return self.compare_expression_clauselist(left, right, **kw)

    def compare_binary(self, left, right, **kw):
        # commutative operators match with operands in either order
        if left.operator == right.operator:
            if operators.is_commutative(left.operator):
                if (
                    self.compare_inner(left.left, right.left, **kw)
                    and self.compare_inner(left.right, right.right, **kw)
                ) or (
                    self.compare_inner(left.left, right.right, **kw)
                    and self.compare_inner(left.right, right.left, **kw)
                ):
                    return ["operator", "negate", "left", "right"]
                else:
                    return COMPARE_FAILED
            else:
                return ["operator", "negate"]
        else:
            return COMPARE_FAILED

    def compare_bindparam(self, left, right, **kw):
        compare_keys = kw.pop("compare_keys", True)
        compare_values = kw.pop("compare_values", True)

        if compare_values:
            omit = []
        else:
            # this means, "skip these, we already compared"
            omit = ["callable", "value"]

        if not compare_keys:
            omit.append("key")

        return omit
991
992
class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
    """Comparison strategy selected by :func:`compare` when ``use_proxies``
    is truthy.

    Column-like elements match when the left element shares lineage with
    (if ``use_proxies``) or hashes equal to the right element or any
    element listed for it in the optional ``equivalents`` mapping.  Tables
    compare on identity only.
    """

    def compare_column_element(
        self, left, right, use_proxies=True, equivalents=(), **kw
    ):
        """Compare ColumnElements using proxies and equivalent collections.

        This is a comparison strategy specific to the ORM.

        Returns ``SKIP_TRAVERSE`` when any candidate matches, otherwise
        ``COMPARE_FAILED``.
        """

        candidates = (right,)
        if equivalents and right in equivalents:
            candidates = equivalents[right].union(candidates)

        # every candidate must be tried before giving up; returning failure
        # from inside the loop would never consult the equivalents
        for candidate in candidates:
            if use_proxies and left.shares_lineage(candidate):
                return SKIP_TRAVERSE
            elif hash(left) == hash(candidate):
                return SKIP_TRAVERSE

        return COMPARE_FAILED

    def compare_column(self, left, right, **kw):
        return self.compare_column_element(left, right, **kw)

    def compare_label(self, left, right, **kw):
        return self.compare_column_element(left, right, **kw)

    def compare_table(self, left, right, **kw):
        # tables compare on identity, since it's not really feasible to
        # compare them column by column with the above rules
        return SKIP_TRAVERSE if left is right else COMPARE_FAILED