1# sql/traversals.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11from collections import deque
12import collections.abc as collections_abc
13import itertools
14from itertools import zip_longest
15import operator
16import typing
17from typing import Any
18from typing import Callable
19from typing import Deque
20from typing import Dict
21from typing import Iterable
22from typing import Optional
23from typing import Set
24from typing import Tuple
25from typing import Type
26
27from . import operators
28from .cache_key import HasCacheKey
29from .visitors import _TraverseInternalsType
30from .visitors import anon_map
31from .visitors import ExternallyTraversible
32from .visitors import HasTraversalDispatch
33from .visitors import HasTraverseInternals
34from .. import util
35from ..util import langhelpers
36from ..util.typing import Self
37
# Returned by a ``compare_<visit_name>`` method to tell
# TraversalComparatorStrategy.compare() that the pair of elements is fully
# handled and their _traverse_internals should not be walked.
SKIP_TRAVERSE = util.symbol("skip_traverse")
# Comparison results.  visit_* methods frequently return plain booleans
# (e.g. ``left == right``) and callers test the result with
# ``is COMPARE_FAILED``, so these must remain the literal False / True.
COMPARE_FAILED = False
COMPARE_SUCCEEDED = True
41
42
def compare(obj1: Any, obj2: Any, **kw: Any) -> bool:
    """Return True if ``obj1`` and ``obj2`` are structurally equivalent.

    All keyword arguments are passed through to the comparison strategy.
    When ``use_proxies`` is truthy a :class:`.ColIdentityComparatorStrategy`
    is used, so columns may also match through shared lineage or
    equivalents; otherwise a plain :class:`.TraversalComparatorStrategy`
    is used.
    """
    use_proxies = kw.get("use_proxies", False)
    strategy: TraversalComparatorStrategy = (
        ColIdentityComparatorStrategy()
        if use_proxies
        else TraversalComparatorStrategy()
    )
    return strategy.compare(obj1, obj2, **kw)
51
52
def _preconfigure_traversals(target_hierarchy: Type[Any]) -> None:
    """Build traversal dispatchers ahead of time for a class hierarchy.

    Every subclass of ``target_hierarchy`` that defines both
    ``_generate_cache_attrs`` and ``_traverse_internals`` gets its cache-key
    attributes generated, plus the copy-internals and get-children dispatch
    methods.
    """
    for subcls in util.walk_subclasses(target_hierarchy):
        eligible = hasattr(subcls, "_generate_cache_attrs") and hasattr(
            subcls, "_traverse_internals"
        )
        if not eligible:
            continue

        subcls._generate_cache_attrs()
        for traversal, attr_name in (
            (_copy_internals, "_generated_copy_internals_traversal"),
            (_get_children, "_generated_get_children_traversal"),
        ):
            traversal.generate_dispatch(
                subcls, subcls._traverse_internals, attr_name
            )
69
70
class HasShallowCopy(HasTraverseInternals):
    """attribute-wide operations that are useful for classes that use
    __slots__ and therefore can't operate on their attributes in a dictionary.

    The copy / to-dict / from-dict methods are generated as Python source
    from the class' ``_traverse_internals`` the first time they are needed.
    They are then stored on that exact class; lookups go through
    ``cls.__dict__``, so each subclass generates its own.
    """

    __slots__ = ()

    if typing.TYPE_CHECKING:

        def _generated_shallow_copy_traversal(self, other: Self) -> None: ...

        def _generated_shallow_from_dict_traversal(
            self, d: Dict[str, Any]
        ) -> None: ...

        def _generated_shallow_to_dict_traversal(self) -> Dict[str, Any]: ...

    @classmethod
    def _generate_shallow_copy(
        cls,
        internal_dispatch: _TraverseInternalsType,
        method_name: str,
    ) -> Callable[[Self, Self], None]:
        """Generate ``method_name(self, other)``, which assigns each attribute
        named in ``internal_dispatch`` from ``self`` onto ``other``."""
        code = "\n".join(
            f"    other.{attrname} = self.{attrname}"
            for attrname, _ in internal_dispatch
        )
        # a def with an empty body is a SyntaxError; use ``pass`` when there
        # are no attributes to copy
        body = code or "    pass"
        meth_text = f"def {method_name}(self, other):\n{body}\n"
        return langhelpers._exec_code_in_env(meth_text, {}, method_name)

    @classmethod
    def _generate_shallow_to_dict(
        cls,
        internal_dispatch: _TraverseInternalsType,
        method_name: str,
    ) -> Callable[[Self], Dict[str, Any]]:
        """Generate ``method_name(self)``, which returns a dict mapping each
        attribute name in ``internal_dispatch`` to its value on ``self``."""
        code = ",\n".join(
            f"        '{attrname}': self.{attrname}"
            for attrname, _ in internal_dispatch
        )
        # an empty ``code`` yields ``return {}``, which is valid
        meth_text = f"def {method_name}(self):\n    return {{{code}}}\n"
        return langhelpers._exec_code_in_env(meth_text, {}, method_name)

    @classmethod
    def _generate_shallow_from_dict(
        cls,
        internal_dispatch: _TraverseInternalsType,
        method_name: str,
    ) -> Callable[[Self, Dict[str, Any]], None]:
        """Generate ``method_name(self, d)``, which sets each attribute named
        in ``internal_dispatch`` on ``self`` from ``d[attrname]``."""
        code = "\n".join(
            f"    self.{attrname} = d['{attrname}']"
            for attrname, _ in internal_dispatch
        )
        # a def with an empty body is a SyntaxError; use ``pass`` when there
        # are no attributes to assign
        body = code or "    pass"
        meth_text = f"def {method_name}(self, d):\n{body}\n"
        return langhelpers._exec_code_in_env(meth_text, {}, method_name)

    def _shallow_from_dict(self, d: Dict[str, Any]) -> None:
        """Populate this object's traversed attributes from ``d``."""
        cls = self.__class__

        shallow_from_dict: Callable[[HasShallowCopy, Dict[str, Any]], None]
        try:
            shallow_from_dict = cls.__dict__[
                "_generated_shallow_from_dict_traversal"
            ]
        except KeyError:
            shallow_from_dict = self._generate_shallow_from_dict(
                cls._traverse_internals,
                "_generated_shallow_from_dict_traversal",
            )

            cls._generated_shallow_from_dict_traversal = shallow_from_dict  # type: ignore  # noqa: E501

        shallow_from_dict(self, d)

    def _shallow_to_dict(self) -> Dict[str, Any]:
        """Return this object's traversed attributes as a new dict."""
        cls = self.__class__

        shallow_to_dict: Callable[[HasShallowCopy], Dict[str, Any]]

        try:
            shallow_to_dict = cls.__dict__[
                "_generated_shallow_to_dict_traversal"
            ]
        except KeyError:
            shallow_to_dict = self._generate_shallow_to_dict(
                cls._traverse_internals, "_generated_shallow_to_dict_traversal"
            )

            cls._generated_shallow_to_dict_traversal = shallow_to_dict  # type: ignore  # noqa: E501
        return shallow_to_dict(self)

    def _shallow_copy_to(self, other: Self) -> None:
        """Copy each traversed attribute of ``self`` onto ``other``."""
        cls = self.__class__

        shallow_copy: Callable[[Self, Self], None]
        try:
            shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"]
        except KeyError:
            shallow_copy = self._generate_shallow_copy(
                cls._traverse_internals, "_generated_shallow_copy_traversal"
            )

            cls._generated_shallow_copy_traversal = shallow_copy  # type: ignore  # noqa: E501
        shallow_copy(self, other)

    def _clone(self, **kw: Any) -> Self:
        """Create a shallow copy"""
        # __new__ bypasses __init__; the traversed attributes are then
        # copied over directly
        c = self.__class__.__new__(self.__class__)
        self._shallow_copy_to(c)
        return c
183
184
class GenerativeOnTraversal(HasShallowCopy):
    """Supplies Generative behavior but making use of traversals to shallow
    copy.

    .. seealso::

        :class:`sqlalchemy.sql.base.Generative`


    """

    __slots__ = ()

    def _generate(self) -> Self:
        """Return a new instance of this class, created without calling
        ``__init__``, whose traversed attributes are shallow copies of this
        object's."""
        klass = self.__class__
        duplicate = klass.__new__(klass)
        self._shallow_copy_to(duplicate)
        return duplicate
203
204
205def _clone(element, **kw):
206 return element._clone()
207
208
class HasCopyInternals(HasTraverseInternals):
    """Mixin providing :meth:`._copy_internals`, which replaces traversed
    attributes with cloned versions of themselves."""

    __slots__ = ()

    def _clone(self, **kw):
        # must be supplied by the concrete class
        raise NotImplementedError()

    def _copy_internals(
        self, *, omit_attrs: Iterable[str] = (), **kw: Any
    ) -> None:
        """Reassign internal elements to be clones of themselves.

        Called during a copy-and-traverse operation on newly
        shallow-copied elements to create a deep copy.

        The given clone function should be used, which may be applying
        additional transformations to the element (i.e. replacement
        traversal, cloned traversal, annotations).

        :param omit_attrs: names of traversed attributes to leave untouched.
         Any iterable is accepted; a one-shot iterator is materialized first
         so that every attribute is checked against the full collection.
        :param kw: forwarded to each copy dispatch method (e.g. ``clone``).
        """

        try:
            traverse_internals = self._traverse_internals
        except AttributeError:
            # user-defined classes may not have a _traverse_internals
            return

        if not isinstance(omit_attrs, collections_abc.Container):
            # ``in`` against a plain iterator consumes it, so later
            # attributes would be tested against a partially drained
            # iterator; materialize it once instead
            omit_attrs = frozenset(omit_attrs)

        for attrname, obj, meth in _copy_internals.run_generated_dispatch(
            self, traverse_internals, "_generated_copy_internals_traversal"
        ):
            if attrname in omit_attrs:
                continue

            if obj is not None:
                result = meth(attrname, self, obj, **kw)
                # a None result means "keep the current value"
                if result is not None:
                    setattr(self, attrname, result)
245
246
class _CopyInternalsTraversal(HasTraversalDispatch):
    """Generate a _copy_internals internal traversal dispatch for classes
    with a _traverse_internals collection.

    Each ``visit_<symbol>`` method receives the current value of one
    traversed attribute and returns its replacement, built by applying
    ``clone`` (default: the module-level ``_clone``) to the clause elements
    it contains.
    """

    def visit_clauseelement(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return clone(element, **kw)

    def visit_clauseelement_list(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return [clone(clause, **kw) for clause in element]

    def visit_clauseelement_tuple(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return tuple([clone(clause, **kw) for clause in element])

    def visit_executable_options(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return tuple([clone(clause, **kw) for clause in element])

    def visit_clauseelement_unordered_set(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return {clone(clause, **kw) for clause in element}

    def visit_clauseelement_tuples(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return [
            tuple(clone(tup_elem, **kw) for tup_elem in elem)
            for elem in element
        ]

    def visit_string_clauseelement_dict(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return {key: clone(value, **kw) for key, value in element.items()}

    def visit_setup_join_tuple(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # each entry is (target, onclause, from_, flags); the element
        # positions may be None and flags are passed through unchanged
        return tuple(
            (
                clone(target, **kw) if target is not None else None,
                clone(onclause, **kw) if onclause is not None else None,
                clone(from_, **kw) if from_ is not None else None,
                flags,
            )
            for (target, onclause, from_, flags) in element
        )

    def visit_memoized_select_entities(self, attrname, parent, element, **kw):
        return self.visit_clauseelement_tuple(attrname, parent, element, **kw)

    def visit_dml_ordered_values(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # sequence of 2-tuples; keys may be plain strings, which are kept
        return [
            (
                (
                    clone(key, **kw)
                    if hasattr(key, "__clause_element__")
                    else key
                ),
                clone(value, **kw),
            )
            for key, value in element
        ]

    def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw):
        # keys may be plain strings, which are kept; values are always cloned
        return {
            (
                clone(key, **kw) if hasattr(key, "__clause_element__") else key
            ): clone(value, **kw)
            for key, value in element.items()
        }

    def visit_dml_multi_values(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # sequence of sequences, each sequence contains a list/dict/tuple

        def maybe_clone(obj):
            # only clause-like objects are cloned; literal values pass through
            return (
                clone(obj, **kw) if hasattr(obj, "__clause_element__") else obj
            )

        def copy(elem):
            if isinstance(elem, (list, tuple)):
                return [maybe_clone(value) for value in elem]
            elif isinstance(elem, dict):
                return {
                    maybe_clone(key): maybe_clone(value)
                    for key, value in elem.items()
                }
            else:
                # TODO: use abc classes
                # raise explicitly: an ``assert`` is stripped under ``-O``
                # and this would then silently produce None
                raise TypeError(
                    "unexpected multi-values element type: "
                    f"{type(elem)!r}"
                )

        return [
            [copy(sub_element) for sub_element in sequence]
            for sequence in element
        ]

    def visit_propagate_attrs(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # returned as-is; not cloned
        return element
370
371
# module-level instance used by HasCopyInternals._copy_internals() and
# _preconfigure_traversals() to generate / run the copy dispatchers
_copy_internals = _CopyInternalsTraversal()
373
374
375def _flatten_clauseelement(element):
376 while hasattr(element, "__clause_element__") and not getattr(
377 element, "is_clause_element", False
378 ):
379 element = element.__clause_element__()
380
381 return element
382
383
class _GetChildrenTraversal(HasTraversalDispatch):
    """Generate a _children_traversal internal traversal dispatch for classes
    with a _traverse_internals collection.

    Each ``visit_<symbol>`` method receives the value of one traversed
    attribute and returns (or yields) the child elements it contains.
    """

    def visit_has_cache_key(self, element, **kw):
        # the GetChildren traversal refers explicitly to ClauseElement
        # structures.  Within these, a plain HasCacheKey is not a
        # ClauseElement, so don't include these.
        return ()

    def visit_clauseelement(self, element, **kw):
        return (element,)

    def visit_clauseelement_list(self, element, **kw):
        return element

    def visit_clauseelement_tuple(self, element, **kw):
        return element

    def visit_clauseelement_tuples(self, element, **kw):
        return itertools.chain.from_iterable(element)

    def visit_fromclause_canonical_column_collection(self, element, **kw):
        return ()

    def visit_string_clauseelement_dict(self, element, **kw):
        return element.values()

    def visit_fromclause_ordered_set(self, element, **kw):
        return element

    def visit_clauseelement_unordered_set(self, element, **kw):
        return element

    def visit_setup_join_tuple(self, element, **kw):
        for target, onclause, from_, flags in element:
            if from_ is not None:
                yield from_

            # skip a missing target rather than yielding None as a child
            if target is not None and not isinstance(target, str):
                yield _flatten_clauseelement(target)

            if onclause is not None and not isinstance(onclause, str):
                yield _flatten_clauseelement(onclause)

    def visit_memoized_select_entities(self, element, **kw):
        return self.visit_clauseelement_tuple(element, **kw)

    def visit_dml_ordered_values(self, element, **kw):
        for k, v in element:
            if hasattr(k, "__clause_element__"):
                yield k
            yield v

    def visit_dml_values(self, element, **kw):
        """Yield values keyed by plain (non-expression) keys in sorted key
        order, then each expression key followed by its value."""
        plain_keys = []
        expr_keys = []
        for key in element:
            if hasattr(key, "__clause_element__"):
                expr_keys.append(key)
            else:
                plain_keys.append(key)

        for key in sorted(plain_keys):
            yield element[key]
        # dict order keeps the output deterministic; hash-based set order
        # of expression objects is not
        for key in expr_keys:
            yield key
            yield element[key]

    def visit_dml_multi_values(self, element, **kw):
        return ()

    def visit_propagate_attrs(self, element, **kw):
        return ()
453
454
# module-level instance used by _preconfigure_traversals() to generate the
# "_generated_get_children_traversal" dispatch for each class
_get_children = _GetChildrenTraversal()
456
457
@util.preload_module("sqlalchemy.sql.elements")
def _resolve_name_for_compare(element, name, anon_map, **kw):
    """Return ``name``, resolved through ``anon_map`` when it is an
    anonymous label; any other name is returned unchanged.

    ``element`` and ``**kw`` are accepted for call compatibility and are
    not used.
    """
    anonymous_label_cls = util.preloaded.sql_elements._anonymous_label
    if not isinstance(name, anonymous_label_cls):
        return name
    return name.apply_map(anon_map)
464
465
class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
    """Structural comparison of two element trees.

    :meth:`.compare` walks both trees breadth-first using a FIFO queue of
    ``(left, right)`` pairs (``self.stack``).  For each pair it first calls
    an optional ``compare_<visit_name>`` method.  It then compares every
    entry of the elements' ``_traverse_internals`` through the
    ``visit_<symbol>`` dispatch methods below.

    A dispatch method either returns a result, where ``COMPARE_FAILED``
    aborts the whole comparison and anything else continues, or pushes
    child pairs onto ``self.stack`` to be compared later.
    """

    __slots__ = "stack", "cache", "anon_map"

    def __init__(self):
        # FIFO work queue of (left, right) pairs still to be compared
        self.stack: Deque[
            Tuple[
                Optional[ExternallyTraversible],
                Optional[ExternallyTraversible],
            ]
        ] = deque()
        # pairs already compared, so repeated sub-structures are skipped
        self.cache: Set[Tuple[Any, Any]] = set()

    def _memoized_attr_anon_map(self):
        # a separate anon_map per side: index 0 is used for the left tree,
        # index 1 for the right tree
        return (anon_map(), anon_map())

    def compare(
        self,
        obj1: ExternallyTraversible,
        obj2: ExternallyTraversible,
        **kw: Any,
    ) -> bool:
        """Return True if ``obj1`` and ``obj2`` are structurally equivalent.

        ``compare_annotations`` (default False) controls whether
        ``_annotations`` attributes take part.  All keyword arguments are
        also forwarded to the ``compare_*`` and ``visit_*`` methods.
        """
        stack = self.stack
        cache = self.cache

        compare_annotations = kw.get("compare_annotations", False)

        stack.append((obj1, obj2))

        while stack:
            left, right = stack.popleft()

            if left is right:
                continue
            elif left is None or right is None:
                # we know they are different so no match
                return False
            elif (left, right) in cache:
                continue
            cache.add((left, right))

            visit_name = left.__visit_name__
            if visit_name != right.__visit_name__:
                return False

            meth = getattr(self, "compare_%s" % visit_name, None)

            if meth:
                attributes_compared = meth(left, right, **kw)
                if attributes_compared is COMPARE_FAILED:
                    return False
                elif attributes_compared is SKIP_TRAVERSE:
                    continue

                # attributes_compared is returned as a list of attribute
                # names that were "handled" by the comparison method above.
                # remaining attribute names in the _traverse_internals
                # will be compared.
            else:
                attributes_compared = ()

            for (
                (left_attrname, left_visit_sym),
                (right_attrname, right_visit_sym),
            ) in zip_longest(
                left._traverse_internals,
                right._traverse_internals,
                fillvalue=(None, None),
            ):
                if not compare_annotations and (
                    (left_attrname == "_annotations")
                    or (right_attrname == "_annotations")
                ):
                    continue

                if (
                    left_attrname != right_attrname
                    or left_visit_sym is not right_visit_sym
                ):
                    return False
                elif left_attrname in attributes_compared:
                    continue

                assert left_visit_sym is not None
                assert left_attrname is not None
                assert right_attrname is not None

                dispatch = self.dispatch(left_visit_sym)
                assert dispatch is not None, (
                    f"{self.__class__} has no dispatch for "
                    f"'{self._dispatch_lookup[left_visit_sym]}'"
                )
                left_child = operator.attrgetter(left_attrname)(left)
                right_child = operator.attrgetter(right_attrname)(right)
                if left_child is None:
                    if right_child is not None:
                        return False
                    else:
                        continue
                elif right_child is None:
                    return False

                comparison = dispatch(
                    left_attrname, left, left_child, right, right_child, **kw
                )
                if comparison is COMPARE_FAILED:
                    return False

        return True

    def compare_inner(self, obj1, obj2, **kw):
        """Compare ``obj1`` and ``obj2`` immediately, using a fresh
        comparator of the same class, instead of queueing the pair on this
        comparator's stack."""
        comparator = self.__class__()
        return comparator.compare(obj1, obj2, **kw)

    def visit_has_cache_key(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
            self.anon_map[1], []
        ):
            return COMPARE_FAILED

    def visit_propagate_attrs(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # only the "plugin_subject" entry takes part in the comparison
        return self.compare_inner(
            left.get("plugin_subject", None), right.get("plugin_subject", None)
        )

    def visit_has_cache_key_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            if l is None:
                if r is not None:
                    return COMPARE_FAILED
                else:
                    continue
            elif r is None:
                return COMPARE_FAILED

            if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
                self.anon_map[1], []
            ):
                return COMPARE_FAILED

    def visit_executable_options(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            if l is None:
                if r is not None:
                    return COMPARE_FAILED
                else:
                    continue
            elif r is None:
                return COMPARE_FAILED

            # options that are not HasCacheKey are compared by equality
            if (
                l._gen_cache_key(self.anon_map[0], [])
                if l._is_has_cache_key
                else l
            ) != (
                r._gen_cache_key(self.anon_map[1], [])
                if r._is_has_cache_key
                else r
            ):
                return COMPARE_FAILED

    def visit_clauseelement(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        self.stack.append((left, right))

    def visit_fromclause_canonical_column_collection(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lcol, rcol in zip_longest(left, right, fillvalue=None):
            self.stack.append((lcol, rcol))

    def visit_fromclause_derived_column_collection(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # derived column collections are not compared
        pass

    def visit_string_clauseelement_dict(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lstr, rstr in zip_longest(
            sorted(left), sorted(right), fillvalue=None
        ):
            if lstr != rstr:
                return COMPARE_FAILED
            self.stack.append((left[lstr], right[rstr]))

    def visit_clauseelement_tuples(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for ltup, rtup in zip_longest(left, right, fillvalue=None):
            if ltup is None or rtup is None:
                return COMPARE_FAILED

            for l, r in zip_longest(ltup, rtup, fillvalue=None):
                self.stack.append((l, r))

    def visit_multi_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            if isinstance(l, str):
                if not isinstance(r, str) or l != r:
                    return COMPARE_FAILED
            elif isinstance(r, str):
                # ``l`` is known not to be a string here
                return COMPARE_FAILED
            else:
                self.stack.append((l, r))

    def visit_clauseelement_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            self.stack.append((l, r))

    def visit_clauseelement_tuple(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            self.stack.append((l, r))

    def _compare_unordered_sequences(self, seq1, seq2, **kw):
        """Return True if ``seq1`` and ``seq2`` have the same length and each
        element of ``seq1`` can be greedily paired with a distinct,
        structurally equal element of ``seq2``.

        Elements of ``seq2`` must be hashable.  Duplicates in ``seq2``
        collapse in the candidate set and therefore cannot all be matched.
        """
        if seq1 is None:
            return seq2 is None
        if seq2 is None or len(seq1) != len(seq2):
            return False

        # candidates not yet paired; built once rather than per element
        remaining = set(seq2)
        for clause in seq1:
            for other_clause in remaining:
                if self.compare_inner(clause, other_clause, **kw):
                    # safe: we break before the set iterator advances
                    remaining.discard(other_clause)
                    break
            else:
                # nothing left in seq2 matches this element
                return False
        return True

    def visit_clauseelement_unordered_set(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return self._compare_unordered_sequences(left, right, **kw)

    def visit_fromclause_ordered_set(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            self.stack.append((l, r))

    def visit_string(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_string_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_string_multi_dict(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lk, rk in zip_longest(
            sorted(left.keys()), sorted(right.keys()), fillvalue=(None, None)
        ):
            if lk != rk:
                return COMPARE_FAILED

            lv, rv = left[lk], right[rk]

            # inspect the values, not the containing dictionaries
            lhc = isinstance(lv, HasCacheKey)
            rhc = isinstance(rv, HasCacheKey)
            if lhc and rhc:
                if lv._gen_cache_key(
                    self.anon_map[0], []
                ) != rv._gen_cache_key(self.anon_map[1], []):
                    return COMPARE_FAILED
            elif lhc != rhc:
                return COMPARE_FAILED
            elif lv != rv:
                return COMPARE_FAILED

    def visit_multi(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        lhc = isinstance(left, HasCacheKey)
        rhc = isinstance(right, HasCacheKey)
        if lhc and rhc:
            if left._gen_cache_key(
                self.anon_map[0], []
            ) != right._gen_cache_key(self.anon_map[1], []):
                return COMPARE_FAILED
        elif lhc != rhc:
            return COMPARE_FAILED
        else:
            return left == right

    def visit_anon_name(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return _resolve_name_for_compare(
            left_parent, left, self.anon_map[0], **kw
        ) == _resolve_name_for_compare(
            right_parent, right, self.anon_map[1], **kw
        )

    def visit_boolean(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_operator(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_type(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left._compare_type_affinity(right)

    def visit_plain_dict(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_dialect_options(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_annotations_key(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # when both sides have annotations, compare the parents' annotation
        # cache keys rather than the raw values
        if left and right:
            return (
                left_parent._annotations_cache_key
                == right_parent._annotations_cache_key
            )
        else:
            return left == right

    def visit_compile_state_funcs(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # functions are compared by their code objects, not their identity
        return tuple((fn.__code__, c_key) for fn, c_key in left) == tuple(
            (fn.__code__, c_key) for fn, c_key in right
        )

    def visit_plain_obj(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_named_ddl_element(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # None matches only None; avoids AttributeError on ``None.name``
        if left is None or right is None:
            return left is right

        return left.name == right.name

    def visit_prefix_sequence(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for (l_clause, l_str), (r_clause, r_str) in zip_longest(
            left, right, fillvalue=(None, None)
        ):
            if l_str != r_str:
                return COMPARE_FAILED
            else:
                self.stack.append((l_clause, r_clause))

    def visit_setup_join_tuple(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # TODO: look at attrname for "legacy_join" and use different structure
        for (
            (l_target, l_onclause, l_from, l_flags),
            (r_target, r_onclause, r_from, r_flags),
        ) in zip_longest(left, right, fillvalue=(None, None, None, None)):
            if l_flags != r_flags:
                return COMPARE_FAILED
            self.stack.append((l_target, r_target))
            self.stack.append((l_onclause, r_onclause))
            self.stack.append((l_from, r_from))

    def visit_memoized_select_entities(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return self.visit_clauseelement_tuple(
            attrname, left_parent, left, right_parent, right, **kw
        )

    def visit_table_hint_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # keys are (table, dialect_name) pairs; compare in a stable order
        left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
        right_keys = sorted(
            right, key=lambda elem: (elem[0].fullname, elem[1])
        )
        for (ltable, ldialect), (rtable, rdialect) in zip_longest(
            left_keys, right_keys, fillvalue=(None, None)
        ):
            if ldialect != rdialect:
                return COMPARE_FAILED
            elif left[(ltable, ldialect)] != right[(rtable, rdialect)]:
                return COMPARE_FAILED
            else:
                self.stack.append((ltable, rtable))

    def visit_statement_hint_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_unknown_structure(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        raise NotImplementedError()

    def visit_dml_ordered_values(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # sequence of (key, value) pairs; both keys and values must match
        for (lk, lv), (rk, rv) in zip_longest(
            left, right, fillvalue=(None, None)
        ):
            if not self._compare_dml_values_or_ce(lk, rk, **kw):
                return COMPARE_FAILED
            if not self._compare_dml_values_or_ce(lv, rv, **kw):
                return COMPARE_FAILED

    def _compare_dml_values_or_ce(self, lv, rv, **kw):
        """Compare two DML keys or values.

        Objects with ``__clause_element__`` are compared structurally via
        :meth:`.compare_inner`.  Anything else (strings, literals) is
        compared by equality.  A clause-like object never matches a plain
        value.
        """
        lvce = hasattr(lv, "__clause_element__")
        rvce = hasattr(rv, "__clause_element__")
        if lvce != rvce:
            return False
        if lvce:
            return self.compare_inner(lv, rv, **kw)
        # plain values cannot be traversed by compare_inner(); equality is
        # the whole comparison
        if lv != rv:
            return False
        return True

    def visit_dml_values(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        if left is None or right is None or len(left) != len(right):
            return COMPARE_FAILED

        if isinstance(left, collections_abc.Sequence):
            for lv, rv in zip(left, right):
                if not self._compare_dml_values_or_ce(lv, rv, **kw):
                    return COMPARE_FAILED
        elif isinstance(right, collections_abc.Sequence):
            return COMPARE_FAILED
        else:
            # dictionaries guaranteed to support insert ordering in
            # py37 so that we can compare the keys in order.  without
            # this, we can't compare SQL expression keys because we don't
            # know which key is which
            for (lk, lv), (rk, rv) in zip(left.items(), right.items()):
                if not self._compare_dml_values_or_ce(lk, rk, **kw):
                    return COMPARE_FAILED
                if not self._compare_dml_values_or_ce(lv, rv, **kw):
                    return COMPARE_FAILED

    def visit_dml_multi_values(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lseq, rseq in zip_longest(left, right, fillvalue=None):
            if lseq is None or rseq is None:
                return COMPARE_FAILED

            for ld, rd in zip_longest(lseq, rseq, fillvalue=None):
                if (
                    self.visit_dml_values(
                        attrname, left_parent, ld, right_parent, rd, **kw
                    )
                    is COMPARE_FAILED
                ):
                    return COMPARE_FAILED

    def visit_params(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def compare_expression_clauselist(self, left, right, **kw):
        # same operator required; associative operators compare their
        # clauses without regard to order
        if left.operator is right.operator:
            if operators.is_associative(left.operator):
                if self._compare_unordered_sequences(
                    left.clauses, right.clauses, **kw
                ):
                    return ["operator", "clauses"]
                else:
                    return COMPARE_FAILED
            else:
                return ["operator"]
        else:
            return COMPARE_FAILED

    def compare_clauselist(self, left, right, **kw):
        return self.compare_expression_clauselist(left, right, **kw)

    def compare_binary(self, left, right, **kw):
        # commutative operators also match with operands swapped
        if left.operator == right.operator:
            if operators.is_commutative(left.operator):
                if (
                    self.compare_inner(left.left, right.left, **kw)
                    and self.compare_inner(left.right, right.right, **kw)
                ) or (
                    self.compare_inner(left.left, right.right, **kw)
                    and self.compare_inner(left.right, right.left, **kw)
                ):
                    return ["operator", "negate", "left", "right"]
                else:
                    return COMPARE_FAILED
            else:
                return ["operator", "negate"]
        else:
            return COMPARE_FAILED

    def compare_bindparam(self, left, right, **kw):
        compare_keys = kw.pop("compare_keys", True)
        compare_values = kw.pop("compare_values", True)

        if compare_values:
            omit = []
        else:
            # this means, "skip these, we already compared"
            omit = ["callable", "value"]

        if not compare_keys:
            omit.append("key")

        return omit
1009
1010
class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
    """Comparator selected by :func:`.compare` when ``use_proxies`` is set.

    Columns and labels match by hash, by shared lineage (when
    ``use_proxies`` is true) or through a caller-supplied ``equivalents``
    mapping.  Tables match by identity only.
    """

    def compare_column_element(
        self, left, right, use_proxies=True, equivalents=(), **kw
    ):
        """Compare ColumnElements using proxies and equivalent collections.

        This is a comparison strategy specific to the ORM.

        The candidates are ``right`` plus, when ``right`` is a key of
        ``equivalents``, every column in ``equivalents[right]``.  ``left``
        matches when it shares lineage with any candidate (only when
        ``use_proxies`` is true) or has the same hash as any candidate.

        :return: ``SKIP_TRAVERSE`` on a match, otherwise ``COMPARE_FAILED``.
        """

        to_compare = (right,)
        if equivalents and right in equivalents:
            to_compare = equivalents[right].union(to_compare)

        for oth in to_compare:
            if use_proxies and left.shares_lineage(oth):
                return SKIP_TRAVERSE
            # test each candidate, not only ``right``, so that equivalents
            # are honored even when use_proxies is False
            elif hash(left) == hash(oth):
                return SKIP_TRAVERSE
        return COMPARE_FAILED

    def compare_column(self, left, right, **kw):
        return self.compare_column_element(left, right, **kw)

    def compare_label(self, left, right, **kw):
        return self.compare_column_element(left, right, **kw)

    def compare_table(self, left, right, **kw):
        # tables compare on identity, since it's not really feasible to
        # compare them column by column with the above rules
        return SKIP_TRAVERSE if left is right else COMPARE_FAILED