1# sql/traversals.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9from __future__ import annotations
10
11from collections import deque
12import collections.abc as collections_abc
13import itertools
14from itertools import zip_longest
15import operator
16import typing
17from typing import Any
18from typing import Callable
19from typing import Deque
20from typing import Dict
21from typing import Iterable
22from typing import Optional
23from typing import Set
24from typing import Tuple
25from typing import Type
26
27from . import operators
28from .cache_key import HasCacheKey
29from .visitors import _TraverseInternalsType
30from .visitors import anon_map
31from .visitors import ExternallyTraversible
32from .visitors import HasTraversalDispatch
33from .visitors import HasTraverseInternals
34from .. import util
35from ..util import langhelpers
36from ..util.typing import Self
37
38
# Returned by a ``compare_<visit_name>()`` method to tell the comparison
# loop that the current pair of elements matched and that their
# ``_traverse_internals`` attributes do not need to be walked.
SKIP_TRAVERSE = util.symbol("skip_traverse")

# Outcome markers for comparison methods.  They are the plain booleans, so a
# ``visit_*`` method may return the result of an ``==`` test directly; the
# comparison loop checks results with ``is COMPARE_FAILED``.
COMPARE_FAILED, COMPARE_SUCCEEDED = False, True
42
43
def compare(obj1: Any, obj2: Any, **kw: Any) -> bool:
    """Structurally compare two traversible SQL elements.

    When ``use_proxies=True`` is passed, :class:`.ColIdentityComparatorStrategy`
    is used; otherwise :class:`.TraversalComparatorStrategy`.  All keyword
    arguments (including ``use_proxies``) are forwarded to the strategy's
    ``compare()`` method.
    """
    strategy_cls = (
        ColIdentityComparatorStrategy
        if kw.get("use_proxies", False)
        else TraversalComparatorStrategy
    )
    strategy: TraversalComparatorStrategy = strategy_cls()
    return strategy.compare(obj1, obj2, **kw)
52
53
def _preconfigure_traversals(target_hierarchy: Type[Any]) -> None:
    """Eagerly generate traversal helpers for a class hierarchy.

    For every subclass of *target_hierarchy* that has both
    ``_generate_cache_attrs`` and ``_traverse_internals``, run
    ``_generate_cache_attrs()`` and then build the generated
    copy-internals and get-children dispatch methods.
    """
    for cls in util.walk_subclasses(target_hierarchy):
        if not (
            hasattr(cls, "_generate_cache_attrs")
            and hasattr(cls, "_traverse_internals")
        ):
            continue

        cls._generate_cache_attrs()

        # copy-internals first, then get-children
        for traversal, generated_name in (
            (_copy_internals, "_generated_copy_internals_traversal"),
            (_get_children, "_generated_get_children_traversal"),
        ):
            traversal.generate_dispatch(
                cls, cls._traverse_internals, generated_name
            )
70
71
class HasShallowCopy(HasTraverseInternals):
    """attribute-wide operations that are useful for classes that use
    __slots__ and therefore can't operate on their attributes in a dictionary.

    The shallow-copy, to-dict and from-dict helpers are generated as Python
    source from the class's ``_traverse_internals`` the first time they are
    needed.  They are then stored on the class.  Lookup goes through
    ``cls.__dict__``, so each subclass generates its own helpers.

    """

    __slots__ = ()

    if typing.TYPE_CHECKING:

        def _generated_shallow_copy_traversal(self, other: Self) -> None: ...

        def _generated_shallow_from_dict_traversal(
            self, d: Dict[str, Any]
        ) -> None: ...

        def _generated_shallow_to_dict_traversal(self) -> Dict[str, Any]: ...

    @classmethod
    def _generate_shallow_copy(
        cls,
        internal_dispatch: _TraverseInternalsType,
        method_name: str,
    ) -> Callable[[Self, Self], None]:
        """Build ``method_name(self, other)``, which assigns every attribute
        named in *internal_dispatch* from ``self`` onto ``other``."""
        # With no attributes, the generated function body would be empty.
        # That is a SyntaxError, so fall back to ``pass``.
        code = (
            "\n".join(
                f"    other.{attrname} = self.{attrname}"
                for attrname, _ in internal_dispatch
            )
            or "    pass"
        )
        meth_text = f"def {method_name}(self, other):\n{code}\n"
        return langhelpers._exec_code_in_env(meth_text, {}, method_name)

    @classmethod
    def _generate_shallow_to_dict(
        cls,
        internal_dispatch: _TraverseInternalsType,
        method_name: str,
    ) -> Callable[[Self], Dict[str, Any]]:
        """Build ``method_name(self)``, which returns a dict mapping every
        attribute named in *internal_dispatch* to its value on ``self``."""
        # an empty attribute list yields ``return {}``, which is valid
        code = ",\n".join(
            f"        '{attrname}': self.{attrname}"
            for attrname, _ in internal_dispatch
        )
        meth_text = f"def {method_name}(self):\n    return {{{code}}}\n"
        return langhelpers._exec_code_in_env(meth_text, {}, method_name)

    @classmethod
    def _generate_shallow_from_dict(
        cls,
        internal_dispatch: _TraverseInternalsType,
        method_name: str,
    ) -> Callable[[Self, Dict[str, Any]], None]:
        """Build ``method_name(self, d)``, which sets every attribute named
        in *internal_dispatch* on ``self`` from ``d[attrname]``."""
        # same empty-body guard as _generate_shallow_copy
        code = (
            "\n".join(
                f"    self.{attrname} = d['{attrname}']"
                for attrname, _ in internal_dispatch
            )
            or "    pass"
        )
        meth_text = f"def {method_name}(self, d):\n{code}\n"
        return langhelpers._exec_code_in_env(meth_text, {}, method_name)

    def _shallow_from_dict(self, d: Dict[str, Any]) -> None:
        """Populate this object's traversed attributes from *d*.

        Raises KeyError (from the generated function) if *d* is missing one
        of the attribute names.
        """
        cls = self.__class__

        shallow_from_dict: Callable[[HasShallowCopy, Dict[str, Any]], None]
        try:
            shallow_from_dict = cls.__dict__[
                "_generated_shallow_from_dict_traversal"
            ]
        except KeyError:
            shallow_from_dict = self._generate_shallow_from_dict(
                cls._traverse_internals,
                "_generated_shallow_from_dict_traversal",
            )

            # cache on this exact class for subsequent calls
            cls._generated_shallow_from_dict_traversal = shallow_from_dict  # type: ignore # noqa: E501

        shallow_from_dict(self, d)

    def _shallow_to_dict(self) -> Dict[str, Any]:
        """Return a new dict of this object's traversed attributes."""
        cls = self.__class__

        shallow_to_dict: Callable[[HasShallowCopy], Dict[str, Any]]

        try:
            shallow_to_dict = cls.__dict__[
                "_generated_shallow_to_dict_traversal"
            ]
        except KeyError:
            shallow_to_dict = self._generate_shallow_to_dict(
                cls._traverse_internals, "_generated_shallow_to_dict_traversal"
            )

            # cache on this exact class for subsequent calls
            cls._generated_shallow_to_dict_traversal = shallow_to_dict  # type: ignore # noqa: E501
        return shallow_to_dict(self)

    def _shallow_copy_to(self, other: Self) -> None:
        """Copy every traversed attribute of this object onto *other*."""
        cls = self.__class__

        shallow_copy: Callable[[Self, Self], None]
        try:
            shallow_copy = cls.__dict__["_generated_shallow_copy_traversal"]
        except KeyError:
            shallow_copy = self._generate_shallow_copy(
                cls._traverse_internals, "_generated_shallow_copy_traversal"
            )

            # cache on this exact class for subsequent calls
            cls._generated_shallow_copy_traversal = shallow_copy  # type: ignore # noqa: E501
        shallow_copy(self, other)

    def _clone(self, **kw: Any) -> Self:
        """Create a shallow copy"""
        # bypass __init__; attributes are filled in by _shallow_copy_to
        c = self.__class__.__new__(self.__class__)
        self._shallow_copy_to(c)
        return c
184
185
class GenerativeOnTraversal(HasShallowCopy):
    """Supplies Generative behavior but making use of traversals to shallow
    copy.

    .. seealso::

        :class:`sqlalchemy.sql.base.Generative`


    """

    __slots__ = ()

    def _generate(self) -> Self:
        # Allocate a bare instance of the same class (no __init__), then
        # copy the traversed attributes onto it.
        klass = self.__class__
        new_instance = klass.__new__(klass)
        self._shallow_copy_to(new_instance)
        return new_instance
204
205
206def _clone(element, **kw):
207 return element._clone()
208
209
class HasCopyInternals(HasTraverseInternals):
    """Mixin that can replace its traversed attributes with copies of
    themselves, driven by ``_traverse_internals``."""

    __slots__ = ()

    def _clone(self, **kw):
        # concrete subclasses provide the actual shallow-copy logic
        raise NotImplementedError()

    def _copy_internals(
        self, *, omit_attrs: Iterable[str] = (), **kw: Any
    ) -> None:
        """Reassign internal elements to be clones of themselves.

        Called during a copy-and-traverse operation on newly
        shallow-copied elements to create a deep copy.

        The given clone function should be used, which may be applying
        additional transformations to the element (i.e. replacement
        traversal, cloned traversal, annotations).

        Attributes listed in *omit_attrs*, and attributes whose current
        value is None, are left untouched.  An attribute is only reassigned
        when its dispatch method returns a non-None result.

        """

        try:
            traverse_internals = self._traverse_internals
        except AttributeError:
            # user-defined classes may not have a _traverse_internals
            return

        dispatch_rows = _copy_internals.run_generated_dispatch(
            self, traverse_internals, "_generated_copy_internals_traversal"
        )
        for attrname, current, copier in dispatch_rows:
            if attrname in omit_attrs or current is None:
                continue

            replacement = copier(attrname, self, current, **kw)
            if replacement is not None:
                setattr(self, attrname, replacement)
246
247
class _CopyInternalsTraversal(HasTraversalDispatch):
    """Generate a _copy_internals internal traversal dispatch for classes
    with a _traverse_internals collection.

    Each ``visit_<symbol>`` method receives the attribute name, the parent
    object and the attribute's current value.  It returns the replacement
    value, built by applying *clone* to the contained SQL elements.
    """

    def visit_clauseelement(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return clone(element, **kw)

    def visit_clauseelement_list(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return [clone(item, **kw) for item in element]

    def visit_clauseelement_tuple(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return tuple(clone(item, **kw) for item in element)

    def visit_executable_options(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return tuple(clone(option, **kw) for option in element)

    def visit_clauseelement_unordered_set(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        return {clone(item, **kw) for item in element}

    def visit_clauseelement_tuples(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # list of tuples; every member of every tuple is cloned
        return [
            tuple(clone(member, **kw) for member in row) for row in element
        ]

    def visit_string_clauseelement_dict(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # string keys are kept as-is; only the values are cloned
        return {name: clone(clause, **kw) for name, clause in element.items()}

    def visit_setup_join_tuple(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        def clone_if_present(obj):
            return None if obj is None else clone(obj, **kw)

        # each entry is (target, onclause, from_, flags); flags pass through
        return tuple(
            (
                clone_if_present(target),
                clone_if_present(onclause),
                clone_if_present(from_),
                flags,
            )
            for target, onclause, from_, flags in element
        )

    def visit_memoized_select_entities(self, attrname, parent, element, **kw):
        return self.visit_clauseelement_tuple(attrname, parent, element, **kw)

    def visit_dml_ordered_values(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        def clone_key(key):
            # keys may be plain strings; only SQL expressions are cloned
            if hasattr(key, "__clause_element__"):
                return clone(key, **kw)
            return key

        # sequence of (key, value) 2-tuples; values are always cloned
        return [(clone_key(key), clone(value, **kw)) for key, value in element]

    def visit_dml_values(self, attrname, parent, element, clone=_clone, **kw):
        copied = {}
        for key, value in element.items():
            # key is processed before value, matching dict-display order
            new_key = (
                clone(key, **kw) if hasattr(key, "__clause_element__") else key
            )
            copied[new_key] = clone(value, **kw)
        return copied

    def visit_dml_multi_values(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # sequence of sequences, each sequence contains a list/dict/tuple

        def clone_if_expr(obj):
            if hasattr(obj, "__clause_element__"):
                return clone(obj, **kw)
            return obj

        def copy_row(row):
            if isinstance(row, (list, tuple)):
                # tuples are returned as lists here
                return [clone_if_expr(value) for value in row]
            elif isinstance(row, dict):
                return {
                    clone_if_expr(key): clone_if_expr(value)
                    for key, value in row.items()
                }
            else:
                # TODO: use abc classes
                assert False

        return [[copy_row(row) for row in sequence] for sequence in element]

    def visit_propagate_attrs(
        self, attrname, parent, element, clone=_clone, **kw
    ):
        # propagate_attrs are shared, never copied
        return element


_copy_internals = _CopyInternalsTraversal()
374
375
376def _flatten_clauseelement(element):
377 while hasattr(element, "__clause_element__") and not getattr(
378 element, "is_clause_element", False
379 ):
380 element = element.__clause_element__()
381
382 return element
383
384
class _GetChildrenTraversal(HasTraversalDispatch):
    """Generate a _children_traversal internal traversal dispatch for classes
    with a _traverse_internals collection.

    Each ``visit_<symbol>`` method receives an attribute's value and returns
    an iterable of the child elements found in it.
    """

    def visit_has_cache_key(self, element, **kw):
        # The GetChildren traversal only reports ClauseElement structures.
        # A plain HasCacheKey is not a ClauseElement, so nothing is yielded.
        return ()

    def visit_clauseelement(self, element, **kw):
        return (element,)

    def visit_clauseelement_list(self, element, **kw):
        return element

    def visit_clauseelement_tuple(self, element, **kw):
        return element

    def visit_clauseelement_tuples(self, element, **kw):
        # flatten the sequence of tuples into a single stream of elements
        return itertools.chain.from_iterable(element)

    def visit_fromclause_canonical_column_collection(self, element, **kw):
        return ()

    def visit_string_clauseelement_dict(self, element, **kw):
        return element.values()

    def visit_fromclause_ordered_set(self, element, **kw):
        return element

    def visit_clauseelement_unordered_set(self, element, **kw):
        return element

    def visit_setup_join_tuple(self, element, **kw):
        for target, onclause, from_, _flags in element:
            if from_ is not None:
                yield from_

            if not isinstance(target, str):
                yield _flatten_clauseelement(target)

            if onclause is None or isinstance(onclause, str):
                continue
            yield _flatten_clauseelement(onclause)

    def visit_memoized_select_entities(self, element, **kw):
        return self.visit_clauseelement_tuple(element, **kw)

    def visit_dml_ordered_values(self, element, **kw):
        for key, value in element:
            # string keys are not children; expression keys are
            if hasattr(key, "__clause_element__"):
                yield key
            yield value

    def visit_dml_values(self, element, **kw):
        expression_keys = {
            k for k in element if hasattr(k, "__clause_element__")
        }
        plain_keys = expression_keys.symmetric_difference(element)

        # values for plain (string) keys first, in sorted key order ...
        yield from (element[name] for name in sorted(plain_keys))
        # ... then each expression key followed by its value
        for key in expression_keys:
            yield key
            yield element[key]

    def visit_dml_multi_values(self, element, **kw):
        return ()

    def visit_propagate_attrs(self, element, **kw):
        return ()


_get_children = _GetChildrenTraversal()
457
458
@util.preload_module("sqlalchemy.sql.elements")
def _resolve_name_for_compare(element, name, anon_map, **kw):
    """Return *name* for comparison purposes.

    If *name* is an ``_anonymous_label``, it is resolved through
    *anon_map* via ``apply_map``; any other value is returned unchanged.
    """
    anonymous_label_cls = util.preloaded.sql_elements._anonymous_label
    if not isinstance(name, anonymous_label_cls):
        return name
    return name.apply_map(anon_map)
465
466
class TraversalComparatorStrategy(HasTraversalDispatch, util.MemoizedSlots):
    """Structural comparison of two traversible SQL element trees.

    :meth:`.compare` pops ``(left, right)`` pairs from ``self.stack`` in
    FIFO order.  For each pair it compares the attributes listed in
    ``_traverse_internals`` through the matching ``visit_<symbol>`` method.

    If a ``compare_<visit_name>`` method exists, it is called first.  It may
    return:

    * ``COMPARE_FAILED`` - the pair does not match;
    * ``SKIP_TRAVERSE`` - the pair matches and its attributes are skipped;
    * a list of attribute names it already handled, which are not compared
      again.

    ``visit_*`` methods receive ``(attrname, left_parent, left,
    right_parent, right, **kw)``.  They either return a boolean or
    ``COMPARE_FAILED``, or push child pairs onto ``self.stack`` to be
    compared later.
    """

    __slots__ = "stack", "cache", "anon_map"

    def __init__(self):
        # (left, right) pairs still waiting to be compared
        self.stack: Deque[
            Tuple[
                Optional[ExternallyTraversible],
                Optional[ExternallyTraversible],
            ]
        ] = deque()
        # (left, right) pairs already compared; repeats are skipped
        self.cache = set()

    def _memoized_attr_anon_map(self):
        # One anonymous-name map per side.  ``anon_map`` is not assigned in
        # __init__; presumably util.MemoizedSlots calls this on first access.
        return (anon_map(), anon_map())

    def compare(
        self,
        obj1: ExternallyTraversible,
        obj2: ExternallyTraversible,
        **kw: Any,
    ) -> bool:
        """Return True if *obj1* and *obj2* are structurally equivalent.

        ``compare_annotations`` (default False) controls whether
        ``_annotations`` attributes take part in the comparison.  All
        keyword arguments are passed on to the compare/visit methods.
        """
        stack = self.stack
        cache = self.cache

        compare_annotations = kw.get("compare_annotations", False)

        stack.append((obj1, obj2))

        while stack:
            left, right = stack.popleft()

            if left is right:
                continue
            elif left is None or right is None:
                # we know they are different so no match
                return False
            elif (left, right) in cache:
                continue
            cache.add((left, right))

            visit_name = left.__visit_name__
            if visit_name != right.__visit_name__:
                return False

            meth = getattr(self, "compare_%s" % visit_name, None)

            if meth:
                attributes_compared = meth(left, right, **kw)
                if attributes_compared is COMPARE_FAILED:
                    return False
                elif attributes_compared is SKIP_TRAVERSE:
                    continue

                # attributes_compared is returned as a list of attribute
                # names that were "handled" by the comparison method above.
                # remaining attribute names in the _traverse_internals
                # will be compared.
            else:
                attributes_compared = ()

            for (
                (left_attrname, left_visit_sym),
                (right_attrname, right_visit_sym),
            ) in zip_longest(
                left._traverse_internals,
                right._traverse_internals,
                fillvalue=(None, None),
            ):
                if not compare_annotations and (
                    (left_attrname == "_annotations")
                    or (right_attrname == "_annotations")
                ):
                    continue

                # differing attribute layouts (or lengths, via the
                # (None, None) fill value) mean the elements differ
                if (
                    left_attrname != right_attrname
                    or left_visit_sym is not right_visit_sym
                ):
                    return False
                elif left_attrname in attributes_compared:
                    continue

                assert left_visit_sym is not None
                assert left_attrname is not None
                assert right_attrname is not None

                dispatch = self.dispatch(left_visit_sym)
                assert dispatch is not None, (
                    f"{self.__class__} has no dispatch for "
                    f"'{self._dispatch_lookup[left_visit_sym]}'"
                )
                left_child = operator.attrgetter(left_attrname)(left)
                right_child = operator.attrgetter(right_attrname)(right)
                if left_child is None:
                    if right_child is not None:
                        return False
                    else:
                        continue
                elif right_child is None:
                    return False

                comparison = dispatch(
                    left_attrname, left, left_child, right, right_child, **kw
                )
                if comparison is COMPARE_FAILED:
                    return False

        return True

    def compare_inner(self, obj1, obj2, **kw):
        """Compare two sub-elements with a fresh comparator of this class
        (separate stack, cache and anon maps)."""
        comparator = self.__class__()
        return comparator.compare(obj1, obj2, **kw)

    def visit_has_cache_key(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        if left._gen_cache_key(self.anon_map[0], []) != right._gen_cache_key(
            self.anon_map[1], []
        ):
            return COMPARE_FAILED

    def visit_propagate_attrs(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # only the "plugin_subject" entry takes part in the comparison
        return self.compare_inner(
            left.get("plugin_subject", None), right.get("plugin_subject", None)
        )

    def visit_has_cache_key_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            if l is None:
                if r is not None:
                    return COMPARE_FAILED
                else:
                    continue
            elif r is None:
                return COMPARE_FAILED

            if l._gen_cache_key(self.anon_map[0], []) != r._gen_cache_key(
                self.anon_map[1], []
            ):
                return COMPARE_FAILED

    def visit_executable_options(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            if l is None:
                if r is not None:
                    return COMPARE_FAILED
                else:
                    continue
            elif r is None:
                return COMPARE_FAILED

            # options that are not HasCacheKey are compared by equality
            if (
                l._gen_cache_key(self.anon_map[0], [])
                if l._is_has_cache_key
                else l
            ) != (
                r._gen_cache_key(self.anon_map[1], [])
                if r._is_has_cache_key
                else r
            ):
                return COMPARE_FAILED

    def visit_clauseelement(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        self.stack.append((left, right))

    def visit_fromclause_canonical_column_collection(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lcol, rcol in zip_longest(left, right, fillvalue=None):
            self.stack.append((lcol, rcol))

    def visit_fromclause_derived_column_collection(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        pass

    def visit_string_clauseelement_dict(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lstr, rstr in zip_longest(
            sorted(left), sorted(right), fillvalue=None
        ):
            if lstr != rstr:
                return COMPARE_FAILED
            self.stack.append((left[lstr], right[rstr]))

    def visit_clauseelement_tuples(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for ltup, rtup in zip_longest(left, right, fillvalue=None):
            if ltup is None or rtup is None:
                return COMPARE_FAILED

            for l, r in zip_longest(ltup, rtup, fillvalue=None):
                self.stack.append((l, r))

    def visit_clauseelement_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            self.stack.append((l, r))

    def visit_clauseelement_tuple(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            self.stack.append((l, r))

    def _compare_unordered_sequences(self, seq1, seq2, **kw):
        """Greedily match each element of *seq1* with a not-yet-matched,
        structurally equal element of *seq2*; succeed only if everything
        matched and both sequences have the same length."""
        if seq1 is None:
            return seq2 is None

        # built once instead of on every outer iteration
        candidates = set(seq2)
        completed: Set[object] = set()
        for clause in seq1:
            for other_clause in candidates.difference(completed):
                if self.compare_inner(clause, other_clause, **kw):
                    completed.add(other_clause)
                    break
        return len(completed) == len(seq1) == len(seq2)

    def visit_clauseelement_unordered_set(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return self._compare_unordered_sequences(left, right, **kw)

    def visit_fromclause_ordered_set(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for l, r in zip_longest(left, right, fillvalue=None):
            self.stack.append((l, r))

    def visit_string(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_string_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_string_multi_dict(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lk, rk in zip_longest(
            sorted(left.keys()), sorted(right.keys()), fillvalue=(None, None)
        ):
            if lk != rk:
                return COMPARE_FAILED

            lv, rv = left[lk], right[rk]

            # inspect the *values*: the enclosing dicts are never
            # HasCacheKey, so testing them would skip the cache-key check
            lhc = isinstance(lv, HasCacheKey)
            rhc = isinstance(rv, HasCacheKey)
            if lhc and rhc:
                if lv._gen_cache_key(
                    self.anon_map[0], []
                ) != rv._gen_cache_key(self.anon_map[1], []):
                    return COMPARE_FAILED
            elif lhc != rhc:
                return COMPARE_FAILED
            elif lv != rv:
                return COMPARE_FAILED

    def visit_multi(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        lhc = isinstance(left, HasCacheKey)
        rhc = isinstance(right, HasCacheKey)
        if lhc and rhc:
            if left._gen_cache_key(
                self.anon_map[0], []
            ) != right._gen_cache_key(self.anon_map[1], []):
                return COMPARE_FAILED
        elif lhc != rhc:
            return COMPARE_FAILED
        else:
            return left == right

    def visit_anon_name(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return _resolve_name_for_compare(
            left_parent, left, self.anon_map[0], **kw
        ) == _resolve_name_for_compare(
            right_parent, right, self.anon_map[1], **kw
        )

    def visit_boolean(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_operator(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_type(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left._compare_type_affinity(right)

    def visit_plain_dict(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_dialect_options(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_annotations_key(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # when both sides carry annotations, compare the parents'
        # precomputed annotation cache keys
        if left and right:
            return (
                left_parent._annotations_cache_key
                == right_parent._annotations_cache_key
            )
        else:
            return left == right

    def visit_with_context_options(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # callables are compared by their code objects, not identity
        return tuple((fn.__code__, c_key) for fn, c_key in left) == tuple(
            (fn.__code__, c_key) for fn, c_key in right
        )

    def visit_plain_obj(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_named_ddl_element(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # guard both sides: reading ``.name`` from None would raise
        if left is None or right is None:
            return left is right

        return left.name == right.name

    def visit_prefix_sequence(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for (l_clause, l_str), (r_clause, r_str) in zip_longest(
            left, right, fillvalue=(None, None)
        ):
            if l_str != r_str:
                return COMPARE_FAILED
            else:
                self.stack.append((l_clause, r_clause))

    def visit_setup_join_tuple(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # TODO: look at attrname for "legacy_join" and use different structure
        for (
            (l_target, l_onclause, l_from, l_flags),
            (r_target, r_onclause, r_from, r_flags),
        ) in zip_longest(left, right, fillvalue=(None, None, None, None)):
            if l_flags != r_flags:
                return COMPARE_FAILED
            self.stack.append((l_target, r_target))
            self.stack.append((l_onclause, r_onclause))
            self.stack.append((l_from, r_from))

    def visit_memoized_select_entities(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return self.visit_clauseelement_tuple(
            attrname, left_parent, left, right_parent, right, **kw
        )

    def visit_table_hint_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # keys are (table, dialect_name); sort them for a stable pairing
        left_keys = sorted(left, key=lambda elem: (elem[0].fullname, elem[1]))
        right_keys = sorted(
            right, key=lambda elem: (elem[0].fullname, elem[1])
        )
        for (ltable, ldialect), (rtable, rdialect) in zip_longest(
            left_keys, right_keys, fillvalue=(None, None)
        ):
            if ldialect != rdialect:
                return COMPARE_FAILED
            elif left[(ltable, ldialect)] != right[(rtable, rdialect)]:
                return COMPARE_FAILED
            else:
                self.stack.append((ltable, rtable))

    def visit_statement_hint_list(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        return left == right

    def visit_unknown_structure(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        raise NotImplementedError()

    def visit_dml_ordered_values(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        # sequence of (key, value) pairs; both the keys and the values must
        # match, consistent with visit_dml_values
        for (lk, lv), (rk, rv) in zip_longest(
            left, right, fillvalue=(None, None)
        ):
            if not self._compare_dml_values_or_ce(lk, rk, **kw):
                return COMPARE_FAILED
            if not self._compare_dml_values_or_ce(lv, rv, **kw):
                return COMPARE_FAILED

    def _compare_dml_values_or_ce(self, lv, rv, **kw):
        """Compare one DML key or value from each side.

        Objects exposing ``__clause_element__`` are compared structurally
        with :meth:`.compare_inner`; both sides must agree on whether they
        are such objects.  Any other pair must first pass ``lv != rv``
        being false.  If both are traversible elements, they are then also
        compared structurally.  Equal plain values such as strings or
        numbers match.
        """
        lvce = hasattr(lv, "__clause_element__")
        rvce = hasattr(rv, "__clause_element__")
        if lvce != rvce:
            return False
        elif lvce:
            # a single structural comparison settles clause elements
            return self.compare_inner(lv, rv, **kw)
        elif lv != rv:
            return False
        elif isinstance(lv, ExternallyTraversible) and isinstance(
            rv, ExternallyTraversible
        ):
            return self.compare_inner(lv, rv, **kw)

        # Equal plain values have no ``__visit_name__``, so running them
        # through compare_inner would raise AttributeError.
        return True

    def visit_dml_values(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        if left is None or right is None or len(left) != len(right):
            return COMPARE_FAILED

        if isinstance(left, collections_abc.Sequence):
            for lv, rv in zip(left, right):
                if not self._compare_dml_values_or_ce(lv, rv, **kw):
                    return COMPARE_FAILED
        elif isinstance(right, collections_abc.Sequence):
            return COMPARE_FAILED
        else:
            # dictionaries guaranteed to support insert ordering in
            # py37 so that we can compare the keys in order. without
            # this, we can't compare SQL expression keys because we don't
            # know which key is which
            for (lk, lv), (rk, rv) in zip(left.items(), right.items()):
                if not self._compare_dml_values_or_ce(lk, rk, **kw):
                    return COMPARE_FAILED
                if not self._compare_dml_values_or_ce(lv, rv, **kw):
                    return COMPARE_FAILED

    def visit_dml_multi_values(
        self, attrname, left_parent, left, right_parent, right, **kw
    ):
        for lseq, rseq in zip_longest(left, right, fillvalue=None):
            if lseq is None or rseq is None:
                return COMPARE_FAILED

            for ld, rd in zip_longest(lseq, rseq, fillvalue=None):
                if (
                    self.visit_dml_values(
                        attrname, left_parent, ld, right_parent, rd, **kw
                    )
                    is COMPARE_FAILED
                ):
                    return COMPARE_FAILED

    def compare_expression_clauselist(self, left, right, **kw):
        # associative operators allow the clauses in any order
        if left.operator is right.operator:
            if operators.is_associative(left.operator):
                if self._compare_unordered_sequences(
                    left.clauses, right.clauses, **kw
                ):
                    return ["operator", "clauses"]
                else:
                    return COMPARE_FAILED
            else:
                return ["operator"]
        else:
            return COMPARE_FAILED

    def compare_clauselist(self, left, right, **kw):
        return self.compare_expression_clauselist(left, right, **kw)

    def compare_binary(self, left, right, **kw):
        # commutative operators also accept the operands swapped
        if left.operator == right.operator:
            if operators.is_commutative(left.operator):
                if (
                    self.compare_inner(left.left, right.left, **kw)
                    and self.compare_inner(left.right, right.right, **kw)
                ) or (
                    self.compare_inner(left.left, right.right, **kw)
                    and self.compare_inner(left.right, right.left, **kw)
                ):
                    return ["operator", "negate", "left", "right"]
                else:
                    return COMPARE_FAILED
            else:
                return ["operator", "negate"]
        else:
            return COMPARE_FAILED

    def compare_bindparam(self, left, right, **kw):
        compare_keys = kw.pop("compare_keys", True)
        compare_values = kw.pop("compare_values", True)

        if compare_values:
            omit = []
        else:
            # this means, "skip these, we already compared"
            omit = ["callable", "value"]

        if not compare_keys:
            omit.append("key")

        return omit
992
993
class ColIdentityComparatorStrategy(TraversalComparatorStrategy):
    """Comparator that matches columns, labels and tables by identity,
    proxy lineage or an ``equivalents`` mapping, rather than by walking
    their structure."""

    def compare_column_element(
        self, left, right, use_proxies=True, equivalents=(), **kw
    ):
        """Compare ColumnElements using proxies and equivalent collections.

        This is a comparison strategy specific to the ORM.

        *right* and every column listed for it in *equivalents* are tried in
        turn.  A candidate matches when ``left.shares_lineage(candidate)``
        is true (only if *use_proxies* is set), or when it has the same hash
        as *left*.  Returns ``SKIP_TRAVERSE`` on a match, else
        ``COMPARE_FAILED``.
        """

        to_compare = (right,)
        if equivalents and right in equivalents:
            to_compare = equivalents[right].union(to_compare)

        for oth in to_compare:
            if use_proxies and left.shares_lineage(oth):
                return SKIP_TRAVERSE
            # test each candidate, not just ``right``; otherwise the
            # equivalents are ignored whenever use_proxies is False
            elif hash(left) == hash(oth):
                return SKIP_TRAVERSE

        return COMPARE_FAILED

    def compare_column(self, left, right, **kw):
        return self.compare_column_element(left, right, **kw)

    def compare_label(self, left, right, **kw):
        return self.compare_column_element(left, right, **kw)

    def compare_table(self, left, right, **kw):
        # tables compare on identity, since it's not really feasible to
        # compare them column by column with the above rules
        return SKIP_TRAVERSE if left is right else COMPARE_FAILED