1# sql/util.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""High level utilities which build upon other modules here."""
10from __future__ import annotations
11
12from collections import deque
13import copy
14from itertools import chain
15import typing
16from typing import AbstractSet
17from typing import Any
18from typing import Callable
19from typing import cast
20from typing import Collection
21from typing import Dict
22from typing import Iterable
23from typing import Iterator
24from typing import List
25from typing import Optional
26from typing import overload
27from typing import Sequence
28from typing import Tuple
29from typing import TYPE_CHECKING
30from typing import TypeVar
31from typing import Union
32
33from . import coercions
34from . import operators
35from . import roles
36from . import visitors
37from ._typing import is_text_clause
38from .annotation import _deep_annotate as _deep_annotate # noqa: F401
39from .annotation import _deep_deannotate as _deep_deannotate # noqa: F401
40from .annotation import _shallow_annotate as _shallow_annotate # noqa: F401
41from .base import _expand_cloned
42from .base import _from_objects
43from .cache_key import HasCacheKey as HasCacheKey # noqa: F401
44from .ddl import sort_tables as sort_tables # noqa: F401
45from .elements import _find_columns as _find_columns
46from .elements import _label_reference
47from .elements import _textual_label_reference
48from .elements import BindParameter
49from .elements import ClauseElement
50from .elements import ColumnClause
51from .elements import ColumnElement
52from .elements import Grouping
53from .elements import KeyedColumnElement
54from .elements import Label
55from .elements import NamedColumn
56from .elements import Null
57from .elements import UnaryExpression
58from .schema import Column
59from .selectable import Alias
60from .selectable import FromClause
61from .selectable import FromGrouping
62from .selectable import Join
63from .selectable import ScalarSelect
64from .selectable import SelectBase
65from .selectable import TableClause
66from .visitors import _ET
67from .. import exc
68from .. import util
69from ..util.typing import Literal
70from ..util.typing import Protocol
71
72if typing.TYPE_CHECKING:
73 from ._typing import _EquivalentColumnMap
74 from ._typing import _LimitOffsetType
75 from ._typing import _TypeEngineArgument
76 from .elements import BinaryExpression
77 from .elements import TextClause
78 from .selectable import _JoinTargetElement
79 from .selectable import _SelectIterable
80 from .selectable import Selectable
81 from .visitors import _TraverseCallableType
82 from .visitors import ExternallyTraversible
83 from .visitors import ExternalTraversal
84 from ..engine.interfaces import _AnyExecuteParams
85 from ..engine.interfaces import _AnyMultiExecuteParams
86 from ..engine.interfaces import _AnySingleExecuteParams
87 from ..engine.interfaces import _CoreSingleExecuteParams
88 from ..engine.row import Row
89
# TypeVar for helpers that accept and return the same kind of
# ColumnElement (see adapt_criterion_to_null()).
_CE = TypeVar("_CE", bound="ColumnElement[Any]")
91
92
def join_condition(
    a: FromClause,
    b: FromClause,
    a_subset: Optional[FromClause] = None,
    consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]] = None,
) -> ColumnElement[bool]:
    """Create a join condition between two tables or selectables.

    e.g.::

        join_condition(tablea, tableb)

    would produce an expression along the lines of::

        tablea.c.id == tableb.c.tablea_id

    The join is determined based on the foreign key relationships
    between the two selectables.  If there are multiple ways
    to join, or no way to join, an error is raised.

    :param a_subset: An optional expression that is a sub-component
     of ``a``.  An attempt will be made to join to just this sub-component
     first before looking at the full ``a`` construct, and if found
     will be successful even if there are other ways to join to ``a``.
     This allows the "right side" of a join to be passed thereby
     providing a "natural join".

    :param consider_as_foreign_keys: An optional set of columns, passed
     through unchanged to ``Join._join_condition()``; presumably it limits
     which columns are treated as foreign keys when locating the condition.

    """
    # thin public wrapper; the actual search is implemented on Join
    return Join._join_condition(
        a,
        b,
        consider_as_foreign_keys=consider_as_foreign_keys,
        a_subset=a_subset,
    )
127
128
def find_join_source(
    clauses: List[FromClause], join_to: FromClause
) -> List[int]:
    """Given a list of FROM clauses and a selectable,
    return the indexes of the clauses which can be joined against the
    selectable.  An empty list is returned if no match is found.

    A clause matches when it is derived from any of the FROM objects of
    ``join_to``.  Each matching index appears once, in ascending order.

    e.g.::

        clause1 = table1.join(table2)
        clause2 = table4.join(table5)

        join_to = table2.join(table3)

        find_join_source([clause1, clause2], join_to) == [0]

    """

    selectables = list(_from_objects(join_to))

    # stop at the first selectable a clause derives from, so that a clause
    # deriving from several of them is still reported only once
    return [
        i
        for i, f in enumerate(clauses)
        if any(f.is_derived_from(s) for s in selectables)
    ]
155
156
def find_left_clause_that_matches_given(
    clauses: Sequence[FromClause], join_from: FromClause
) -> List[int]:
    """Return the indexes of the FROM ``clauses`` derived from
    ``join_from``.

    A first, "liberal" pass accepts every clause for which
    ``is_derived_from()`` is true against one of the FROM objects of
    ``join_from``.  When that finds more than one candidate, a second,
    "conservative" pass keeps only candidates sharing a surface selectable
    (see :func:`.surface_selectables`) with one of those FROM objects; the
    conservative result is returned if it is non-empty.

    """

    targets = list(_from_objects(join_from))

    # basic check, if a clause is derived from a target.
    # this can be joins containing a table, or an aliased table
    # or select statement matching to a table. This check
    # will match a table to a selectable that is adapted from
    # that table. With Query, this suits the case where a join
    # is being made to an adapted entity
    liberal_idx = [
        pos
        for pos, candidate in enumerate(clauses)
        if any(candidate.is_derived_from(target) for target in targets)
    ]

    if len(liberal_idx) <= 1:
        return liberal_idx

    # in an extremely small set of use cases, a join is being made where
    # there are multiple FROM clauses where our target table is represented
    # in more than one, such as embedded or similar. in this case, do
    # another pass where we try to get a more exact match where we aren't
    # looking at adaption relationships.
    conservative_idx = [
        pos
        for pos in liberal_idx
        if any(
            set(surface_selectables(clauses[pos])).intersection(
                surface_selectables(target)
            )
            for target in targets
        )
    ]
    return conservative_idx or liberal_idx
199
200
def find_left_clause_to_join_from(
    clauses: Sequence[FromClause],
    join_to: _JoinTargetElement,
    onclause: Optional[ColumnElement[Any]],
) -> List[int]:
    """Return the indexes of the FROM ``clauses`` that ``join_to`` can be
    joined from, given an optional ON clause.

    With an ``onclause`` and more than one clause, a clause qualifies when
    its columns together with those of one of ``join_to``'s FROM objects
    cover every column in the ON clause.  With an ``onclause`` and a single
    clause, that clause qualifies outright.  Without an ``onclause``,
    ``Join._can_join()`` decides.  Clauses hidden by another clause's
    ``_hide_froms`` are dropped when several qualify.  If an ``onclause``
    was given but nothing qualified, every index is returned.

    """
    targets = set(_from_objects(join_to))

    # given more than one candidate clause plus an ON clause, use the ON
    # clause's columns for a more specific answer; otherwise don't try to
    # limit, after all, "ON TRUE" is a valid on clause
    if len(clauses) > 1 and onclause is not None:
        resolve_ambiguity = True
        cols_in_onclause = _find_columns(onclause)
    else:
        resolve_ambiguity = False
        cols_in_onclause = None

    def _joinable(candidate: FromClause, target: FromClause) -> bool:
        if resolve_ambiguity:
            assert cols_in_onclause is not None
            return set(candidate.c).union(target.c).issuperset(
                cols_in_onclause
            )
        return onclause is not None or Join._can_join(candidate, target)

    idx = [
        pos
        for pos, candidate in enumerate(clauses)
        if any(
            _joinable(candidate, target)
            for target in targets.difference([candidate])
        )
    ]

    if len(idx) > 1:
        # this is the same "hide froms" logic from
        # Selectable._get_display_froms
        hidden = set(
            chain.from_iterable(
                _expand_cloned(f._hide_froms) for f in clauses
            )
        )
        idx = [pos for pos in idx if clauses[pos] not in hidden]

    if idx or onclause is None:
        return idx
    # onclause was given and none of them resolved, so assume
    # all indexes can match
    return list(range(len(clauses)))
256
257
def visit_binary_product(
    fn: Callable[
        [BinaryExpression[Any], ColumnElement[Any], ColumnElement[Any]], None
    ],
    expr: ColumnElement[Any],
) -> None:
    """Produce a traversal of the given expression, delivering
    column comparisons to the given function.

    The function is of the form::

        def my_fn(binary, left, right): ...

    For each binary expression located which has a
    comparison operator, the product of "left" and
    "right" will be delivered to that function,
    in terms of that binary.

    Hence an expression like::

        and_((a + b) == q + func.sum(e + f), j == r)

    would have the traversal:

    .. sourcecode:: text

        a <eq> q
        a <eq> e
        a <eq> f
        b <eq> q
        b <eq> e
        b <eq> f
        j <eq> r

    That is, every combination of "left" and
    "right" that doesn't further contain
    a binary comparison is passed as pairs.

    """

    def visit(element: ClauseElement) -> Iterator[ColumnElement[Any]]:
        # Yields the "leaf" column elements found under ``element``.
        # A comparison consumes its operands (delivering pairs to ``fn``)
        # and yields nothing upward.
        if isinstance(element, ScalarSelect):
            # we don't want to dig into correlated subqueries,
            # those are just column elements by themselves
            yield element
        elif element.__visit_name__ == "binary" and operators.is_comparison(
            element.operator  # type: ignore
        ):
            binary = cast("BinaryExpression[Any]", element)
            # NOTE: the right operand is traversed again for each left
            # leaf, so comparisons nested inside the right operand reach
            # ``fn`` once per left leaf.  This is existing behavior and is
            # kept as-is to preserve the order of calls to ``fn``.
            for left_leaf in visit(binary.left):
                for right_leaf in visit(binary.right):
                    fn(binary, left_leaf, right_leaf)
        else:
            if isinstance(element, ColumnClause):
                yield element
            for elem in element.get_children():
                yield from visit(elem)

    list(visit(expr))
    visit = None  # type: ignore # remove gc cycles
322
323
def find_tables(
    clause: ClauseElement,
    *,
    check_columns: bool = False,
    include_aliases: bool = False,
    include_joins: bool = False,
    include_selects: bool = False,
    include_crud: bool = False,
) -> List[TableClause]:
    """locate Table objects within the given expression.

    Every ``table`` node found by ``visitors.traverse`` is collected.  The
    keyword flags additionally collect SELECT / compound SELECT nodes,
    joins, alias-like nodes (alias, subquery, tablesample, lateral), the
    ``.table`` of INSERT/UPDATE/DELETE nodes, and the ``.table`` of every
    column node.  Results are returned in visitation order and may contain
    repeats.
    """
    tables: List[TableClause] = []
    collect = tables.append
    _visitors: Dict[str, _TraverseCallableType[Any]] = {}

    def _register(names, handler):
        for name in names:
            _visitors[name] = handler

    if include_selects:
        _register(("select", "compound_select"), collect)
    if include_joins:
        _register(("join",), collect)
    if include_aliases:
        _register(("alias", "subquery", "tablesample", "lateral"), collect)
    if include_crud:
        _register(
            ("insert", "update", "delete"),
            lambda ent: tables.append(ent.table),
        )
    if check_columns:
        _register(("column",), lambda column: tables.append(column.table))

    _visitors["table"] = collect

    visitors.traverse(clause, {}, _visitors)
    return tables
365
366
def unwrap_order_by(clause: Any) -> Any:
    """Break up an 'order by' expression into individual column-expressions,
    without DESC/ASC/NULLS FIRST/NULLS LAST

    The expression tree is walked breadth-first.  Ordering modifiers
    (``UnaryExpression`` objects whose modifier is an ordering modifier)
    and other non-column nodes are descended into.  Labels (other than
    labels of scalar subqueries), a grouping directly inside such a label,
    and ``_label_reference`` wrappers are unwrapped; textual label
    references are dropped.  Each remaining column expression is returned
    once, in the order first encountered.

    """

    # ``cols`` de-duplicates; ``result`` keeps first-seen order
    cols = util.column_set()
    result = []
    stack = deque([clause])

    # examples
    # column -> ASC/DESC == column
    # column -> ASC/DESC -> label == column
    # column -> label -> ASC/DESC -> label == column
    # scalar_select -> label -> ASC/DESC == scalar_select -> label

    while stack:
        t = stack.popleft()
        # any column expression other than an ordering modifier is a
        # candidate result (possibly after unwrapping below)
        if isinstance(t, ColumnElement) and (
            not isinstance(t, UnaryExpression)
            or not operators.is_ordering_modifier(t.modifier)  # type: ignore
        ):
            # unwrap labels, except labels of scalar subqueries which are
            # kept labeled; a grouping directly inside the label is
            # unwrapped too.  The unwrapped element is re-queued.
            if isinstance(t, Label) and not isinstance(
                t.element, ScalarSelect
            ):
                t = t.element

                if isinstance(t, Grouping):
                    t = t.element

                stack.append(t)
                continue
            elif isinstance(t, _label_reference):
                # re-queue the wrapped element in place of the reference
                t = t.element

                stack.append(t)
                continue
            # textual label references carry no expression; skip them
            if isinstance(t, (_textual_label_reference)):
                continue
            if t not in cols:
                cols.add(t)
                result.append(t)

        else:
            # ordering modifiers and non-column nodes: descend into them
            for c in t.get_children():
                stack.append(c)
    return result
412
413
def unwrap_label_reference(element):
    """Run a replacement traversal over ``element`` in which every
    ``_label_reference`` is replaced by the element it wraps.

    Meeting a ``_textual_label_reference`` is an internal error (assertion):
    there is no wrapped expression to substitute for it.
    """

    def _swap_label_reference(
        elem: ExternallyTraversible, **kw: Any
    ) -> Optional[ExternallyTraversible]:
        if isinstance(elem, _label_reference):
            return elem.element
        if isinstance(elem, _textual_label_reference):
            assert False, "can't unwrap a textual label reference"
        # no replacement for anything else
        return None

    return visitors.replacement_traverse(element, {}, _swap_label_reference)
425
426
def expand_column_list_from_order_by(collist, order_by):
    """Given the columns clause and ORDER BY of a selectable,
    return a list of column expressions that can be added to the collist
    corresponding to the ORDER BY, without repeating those already
    in the collist.

    Columns in ``collist`` that have an ``_order_by_label_element`` are
    compared through their ``.element``.

    """
    present = set()
    for col in collist:
        if col._order_by_label_element is not None:
            present.add(col.element)
        else:
            present.add(col)

    wanted = chain.from_iterable(unwrap_order_by(o) for o in order_by)
    return [expr for expr in wanted if expr not in present]
442
443
def clause_is_present(clause, search):
    """Given a target clause and a second to search within, return True
    if the target is plainly present in the search without any
    subqueries or aliases involved.

    Basically descends through Joins (and FROM groupings) via
    :func:`.surface_selectables`.

    """
    # compare with == rather than "is" so that Annotated's comparison
    # takes part in the match
    return any(clause == elem for elem in surface_selectables(search))
458
459
def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]:
    """Yield the non-join FROM elements of ``clause`` from left to right,
    descending through :class:`.Join` and :class:`.FromGrouping`."""
    pending: List[FromClause] = [clause]
    while pending:
        current = pending.pop()
        if isinstance(current, Join):
            # push the right side first so the left side is produced first
            pending.append(current.right)
            pending.append(current.left)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
        else:
            yield current
468
469
def surface_selectables(clause):
    """Yield ``clause`` followed by everything reachable from it through
    :class:`.Join` (both sides) and :class:`.FromGrouping` wrappers.

    Joins and groupings are yielded themselves as well as descended into;
    a join's right side is visited before its left side.
    """
    pending = [clause]
    while pending:
        current = pending.pop()
        yield current
        if isinstance(current, Join):
            pending.append(current.left)
            pending.append(current.right)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
479
480
def surface_selectables_only(clause: ClauseElement) -> Iterator[ClauseElement]:
    """Yield the "surface" FROM-level objects reachable from ``clause``.

    Joins and FROM groupings are descended into rather than yielded; a
    column with a parent table is replaced by that table, while a column
    without one is yielded itself.  Tables, aliases and any other non-None
    element are yielded as-is.  Every element taken off the stack produces
    at most one yield.
    """
    stack = [clause]
    while stack:
        elem = stack.pop()
        # tables and aliases are terminal: yield them exactly once.  This
        # branch must be part of the elif chain; otherwise the trailing
        # "elif elem is not None" would yield the same element again.
        if isinstance(elem, (TableClause, Alias)):
            yield elem
        elif isinstance(elem, Join):
            stack.extend((elem.left, elem.right))
        elif isinstance(elem, FromGrouping):
            stack.append(elem.element)
        elif isinstance(elem, ColumnClause):
            if elem.table is not None:
                stack.append(elem.table)
            else:
                yield elem
        elif elem is not None:
            yield elem
498
499
def extract_first_column_annotation(column, annotation_name):
    """Search ``column`` and its children breadth-first for the first
    element whose ``_annotations`` contain ``annotation_name`` and return
    that value; return None if no element has it.

    ``FromGrouping`` and ``SelectBase`` children are not descended into.
    """
    skip_types = (FromGrouping, SelectBase)

    queue = deque([column])
    while queue:
        node = queue.popleft()
        if annotation_name in node._annotations:
            return node._annotations[annotation_name]
        queue.extend(
            child
            for child in node.get_children()
            if not isinstance(child, skip_types)
        )
    return None
513
514
def selectables_overlap(left: FromClause, right: FromClause) -> bool:
    """Return True if left/right have some overlapping selectable

    Overlap is judged on the :func:`.surface_selectables` of each side.
    """
    left_surface = set(surface_selectables(left))
    return not left_surface.isdisjoint(surface_selectables(right))
521
522
def bind_values(clause):
    """Return an ordered list of "bound" values in the given clause.

    The ``effective_value`` of every bind parameter is collected in the
    order ``visitors.traverse`` visits them.

    E.g.::

        >>> expr = and_(table.c.foo == 5, table.c.foo == 7)
        >>> bind_values(expr)
        [5, 7]
    """
    values = []
    visitors.traverse(
        clause,
        {},
        {"bindparam": lambda bind: values.append(bind.effective_value)},
    )
    return values
540
541
542def _quote_ddl_expr(element):
543 if isinstance(element, str):
544 element = element.replace("'", "''")
545 return "'%s'" % element
546 else:
547 return repr(element)
548
549
550class _repr_base:
551 _LIST: int = 0
552 _TUPLE: int = 1
553 _DICT: int = 2
554
555 __slots__ = ("max_chars",)
556
557 max_chars: int
558
559 def trunc(self, value: Any) -> str:
560 rep = repr(value)
561 lenrep = len(rep)
562 if lenrep > self.max_chars:
563 segment_length = self.max_chars // 2
564 rep = (
565 rep[0:segment_length]
566 + (
567 " ... (%d characters truncated) ... "
568 % (lenrep - self.max_chars)
569 )
570 + rep[-segment_length:]
571 )
572 return rep
573
574
def _repr_single_value(value):
    """Return ``repr(value)`` with its middle elided when it is longer
    than 300 characters (see ``_repr_base.trunc``)."""
    truncator = _repr_base()
    truncator.max_chars = 300
    return truncator.trunc(value)
579
580
class _repr_row(_repr_base):
    """Provide a string view of a row.

    The row renders like a tuple (including the trailing comma for a
    single element), with each value's repr truncated to ``max_chars``.
    """

    __slots__ = ("row",)

    def __init__(self, row: Row[Any], max_chars: int = 300):
        self.row = row
        self.max_chars = max_chars

    def __repr__(self) -> str:
        rendered = [self.trunc(item) for item in self.row]
        trailing_comma = "," if len(self.row) == 1 else ""
        return "(%s%s)" % (", ".join(rendered), trailing_comma)
596
597
598class _long_statement(str):
599 def __str__(self) -> str:
600 lself = len(self)
601 if lself > 500:
602 lleft = 250
603 lright = 100
604 trunc = lself - lleft - lright
605 return (
606 f"{self[0:lleft]} ... {trunc} "
607 f"characters truncated ... {self[-lright:]}"
608 )
609 else:
610 return str.__str__(self)
611
612
class _repr_params(_repr_base):
    """Provide a string view of bound parameters.

    Truncates display to a given number of 'multi' parameter sets,
    as well as long values to a given number of characters.

    :param params: the parameters to display.  A list, tuple or dict is
     rendered structurally; anything else falls back to a truncated
     ``repr()``.
    :param batches: when ``ismulti`` is true, the maximum number of
     parameter sets shown; larger collections show the first
     ``batches - 2`` sets and the last 2 around an elision note.
    :param max_params: the maximum number of parameters shown within one
     parameter set; when exceeded, the first and last ``max_params // 2``
     are kept around an elision note.
    :param max_chars: per-value ``repr()`` length limit, see
     :meth:`_repr_base.trunc`.
    :param ismulti: ``None`` renders ``params`` as a single truncated
     ``repr()``; ``True`` treats ``params`` as a collection of parameter
     sets; ``False`` treats it as one parameter set.

    """

    __slots__ = "params", "batches", "ismulti", "max_params"

    def __init__(
        self,
        params: Optional[_AnyExecuteParams],
        batches: int,
        max_params: int = 100,
        max_chars: int = 300,
        ismulti: Optional[bool] = None,
    ):
        self.params = params
        self.ismulti = ismulti
        self.batches = batches
        self.max_chars = max_chars
        self.max_params = max_params

    def __repr__(self) -> str:
        if self.ismulti is None:
            return self.trunc(self.params)

        if isinstance(self.params, list):
            typ = self._LIST
        elif isinstance(self.params, tuple):
            typ = self._TUPLE
        elif isinstance(self.params, dict):
            typ = self._DICT
        else:
            return self.trunc(self.params)

        if self.ismulti:
            multi_params = cast(
                "_AnyMultiExecuteParams",
                self.params,
            )

            if len(self.params) > self.batches:
                # show the first (batches - 2) sets and the last 2 sets;
                # the closing / opening bracket of each half is sliced off
                # so the halves read as one bracketed sequence.
                # NOTE: assumes batches >= 2; smaller values make the head
                # slice index negative.
                msg = (
                    " ... displaying %i of %i total bound parameter sets ... "
                )
                return " ".join(
                    (
                        self._repr_multi(
                            multi_params[: self.batches - 2],
                            typ,
                        )[0:-1],
                        msg % (self.batches, len(self.params)),
                        self._repr_multi(multi_params[-2:], typ)[1:],
                    )
                )
            else:
                return self._repr_multi(multi_params, typ)
        else:
            return self._repr_params(
                cast(
                    "_AnySingleExecuteParams",
                    self.params,
                ),
                typ,
            )

    def _repr_multi(
        self,
        multi_params: _AnyMultiExecuteParams,
        typ: int,
    ) -> str:
        """Render a sequence of parameter sets, bracketed per ``typ``.

        The container kind of the first set decides how every set is
        rendered.
        """
        if multi_params:
            if isinstance(multi_params[0], list):
                elem_type = self._LIST
            elif isinstance(multi_params[0], tuple):
                elem_type = self._TUPLE
            elif isinstance(multi_params[0], dict):
                elem_type = self._DICT
            else:
                assert False, "Unknown parameter type %s" % (
                    type(multi_params[0])
                )

            elements = ", ".join(
                self._repr_params(params, elem_type) for params in multi_params
            )
        else:
            elements = ""

        if typ == self._LIST:
            return "[%s]" % elements
        else:
            return "(%s)" % elements

    def _get_batches(self, params: Iterable[Any]) -> Any:
        """Split ``params`` for display.

        Returns ``(all_items, None, None)`` when there are at most
        ``max_params`` items; otherwise ``(head, tail, omitted_count)``
        where ``head`` and ``tail`` hold ``max_params // 2`` items each.
        """
        lparams = list(params)
        lenparams = len(lparams)
        if lenparams > self.max_params:
            lleft = max(self.max_params // 2, 0)
            return (
                lparams[0:lleft],
                # explicit start index: lparams[-0:] would be the whole list
                lparams[lenparams - lleft :],
                lenparams - 2 * lleft,
            )
        else:
            return lparams, None, None

    def _repr_params(
        self,
        params: _AnySingleExecuteParams,
        typ: int,
    ) -> str:
        """Render a single parameter set according to its kind ``typ``."""
        # compare the int tags by value; identity of ints is an
        # implementation detail
        if typ == self._DICT:
            return self._repr_param_dict(
                cast("_CoreSingleExecuteParams", params)
            )
        elif typ == self._TUPLE:
            return self._repr_param_tuple(cast("Sequence[Any]", params))
        else:
            return self._repr_param_list(params)

    def _repr_param_dict(self, params: _CoreSingleExecuteParams) -> str:
        trunc = self.trunc
        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params.items())

        # a tail batch (possibly empty) is present only when truncated
        if items_second_batch is not None:
            text = "{%s" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_first_batch
                )
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s}" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_second_batch
                )
            )
        else:
            text = "{%s}" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_first_batch
                )
            )
        return text

    def _repr_param_tuple(self, params: Sequence[Any]) -> str:
        trunc = self.trunc

        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params)

        if items_second_batch is not None:
            text = "(%s" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s)" % (
                ", ".join(trunc(value) for value in items_second_batch),
            )
        else:
            # a one-element tuple keeps its trailing comma
            text = "(%s%s)" % (
                ", ".join(trunc(value) for value in items_first_batch),
                "," if len(items_first_batch) == 1 else "",
            )
        return text

    def _repr_param_list(self, params: _AnySingleExecuteParams) -> str:
        trunc = self.trunc
        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params)

        if items_second_batch is not None:
            text = "[%s" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s]" % (
                ", ".join(trunc(value) for value in items_second_batch)
            )
        else:
            text = "[%s]" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
        return text
813
814
def adapt_criterion_to_null(crit: _CE, nulls: Collection[Any]) -> _CE:
    """given criterion containing bind params, convert selected elements
    to IS NULL.

    Returns the result of ``visitors.cloned_traverse`` over ``crit``.
    Each binary expression with a :class:`.BindParameter` operand whose
    ``_identifying_key`` is in ``nulls`` is rewritten as
    ``<other operand> IS NULL`` (negating to ``IS NOT``).  The original
    operator is replaced whatever it was.

    """

    def _to_is_null(binary):
        # NULL always ends up on the right-hand side
        binary.right = Null()
        binary.operator = operators.is_
        binary.negate = operators.is_not

    def visit_binary(binary):
        left, right = binary.left, binary.right
        if isinstance(left, BindParameter) and left._identifying_key in nulls:
            # reverse order if the NULL is on the left side
            binary.left = right
            _to_is_null(binary)
        elif (
            isinstance(right, BindParameter)
            and right._identifying_key in nulls
        ):
            _to_is_null(binary)

    return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
840
841
def splice_joins(
    left: Optional[FromClause],
    right: Optional[FromClause],
    stop_on: Optional[FromClause] = None,
) -> Optional[FromClause]:
    """Return ``right`` with its left-hand chain of joins re-targeted
    against ``left``.

    If ``left`` is None, ``right`` is returned unchanged.  Otherwise the
    chain of :class:`.Join` objects reached by repeatedly following
    ``.left`` from ``right`` is cloned.  The ``onclause`` of each clone,
    and the element at the bottom of the chain, are passed through
    ``ClauseAdapter(left).traverse()``.  The clones are then re-linked
    through their ``.left`` attributes.  Descent stops at ``stop_on``
    (which is adapted, not cloned) or at the first non-join element.

    :return: the top-most cloned join, or the adapted ``right`` itself
     when ``right`` is not a join (or is ``stop_on``).
    """
    if left is None:
        return right

    # entries are (element to process, parent clone whose .left receives
    # the processed element -- None for the top of the chain)
    stack: List[Tuple[Optional[FromClause], Optional[Join]]] = [(right, None)]

    adapter = ClauseAdapter(left)
    ret = None
    while stack:
        (right, prevright) = stack.pop()
        if isinstance(right, Join) and right is not stop_on:
            # clone this join, adapt its ON clause, and continue down its
            # left side
            right = right._clone()
            right.onclause = adapter.traverse(right.onclause)
            stack.append((right.left, right))
        else:
            # bottom of the chain (or stop_on): adapt the element itself
            right = adapter.traverse(right)
        if prevright is not None:
            assert right is not None
            # attach the processed element to its cloned parent
            prevright.left = right
        if ret is None:
            # the first element processed is the top of the result
            ret = right

    return ret
869
870
@overload
def reduce_columns(
    columns: Iterable[ColumnElement[Any]],
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Sequence[ColumnElement[Any]]: ...


@overload
def reduce_columns(
    columns: _SelectIterable,
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Sequence[Union[ColumnElement[Any], TextClause]]: ...


def reduce_columns(
    columns: _SelectIterable,
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Collection[Union[ColumnElement[Any], TextClause]]:
    r"""given a list of columns, return a 'reduced' set based on natural
    equivalents.

    the set is reduced to the smallest list of columns which have no natural
    equivalent present in the list.  A "natural equivalent" means that two
    columns will ultimately represent the same value because they are related
    by a foreign key.

    \*clauses is an optional list of join clauses which will be traversed
    to further identify columns that are "equivalent".

    \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
    whose tables are not yet configured, or columns that aren't yet present.
    It may also specify 'only_synonyms'; when true, a column is only
    omitted in favor of an equivalent column that has the same name.
    Any other keys in \**kw are ignored.

    This function is primarily used to determine the most minimal "primary
    key" from a selectable, by reducing the set of primary key columns present
    in the selectable to just those that are not repeated.

    The result is ``util.OrderedSet(columns)`` with the omitted columns
    removed; text clauses in ``columns`` are never omitted.

    """
    ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
    only_synonyms = kw.pop("only_synonyms", False)

    column_set = util.OrderedSet(columns)
    # text clauses take no part in the equivalence analysis
    cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference(
        c for c in column_set if is_text_clause(c)  # type: ignore
    )

    # phase 1: omit any column that has a foreign key (on any column in
    # its proxy set) pointing at a column sharing lineage with another
    # column of the set
    omit = util.column_set()
    for col in cset_no_text:
        for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
            for c in cset_no_text:
                if c is col:
                    continue
                # fk.column is resolved here, inside the loop over the
                # other columns, so an unresolvable foreign key only raises
                # when there is another column to compare against
                try:
                    fk_col = fk.column
                except exc.NoReferencedColumnError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                except exc.NoReferencedTableError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                if fk_col.shares_lineage(c) and (
                    not only_synonyms or c.name == col.name
                ):
                    omit.add(col)
                    break

    if clauses:

        # phase 2: for each "a = b" in the given clauses where both sides
        # belong to the proxy sets of the columns not yet omitted, omit the
        # last column (in set order) sharing lineage with the right side
        def visit_binary(binary):
            if binary.operator == operators.eq:
                cols = util.column_set(
                    chain(
                        *[c.proxy_set for c in cset_no_text.difference(omit)]
                    )
                )
                if binary.left in cols and binary.right in cols:
                    for c in reversed(cset_no_text):
                        if c.shares_lineage(binary.right) and (
                            not only_synonyms or c.name == binary.left.name
                        ):
                            omit.add(c)
                            break

        for clause in clauses:
            if clause is not None:
                visitors.traverse(clause, {}, {"binary": visit_binary})

    return column_set.difference(omit)
969
970
def criterion_as_pairs(
    expression,
    consider_as_foreign_keys=None,
    consider_as_referenced_keys=None,
    any_operator=False,
):
    """traverse an expression and locate binary criterion pairs.

    Only ``==`` comparisons are considered unless ``any_operator`` is set,
    and both operands must be column elements.  Each located pair is a
    ``(referenced, referencing)`` tuple, oriented as follows:

    * with ``consider_as_foreign_keys``, the operand found in that
      collection is the referencing (second) side;
    * with ``consider_as_referenced_keys``, the operand found in that
      collection is the referenced (first) side;
    * otherwise both operands must be :class:`.Column` objects and
      ``Column.references()`` decides.

    :raises ArgumentError: if both key collections are given.
    """

    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError(
            "Can only specify one of "
            "'consider_as_foreign_keys' or "
            "'consider_as_referenced_keys'"
        )

    pairs: List[Tuple[ColumnElement[Any], ColumnElement[Any]]] = []

    def _keyed_side(keys, left, right):
        # Return (keyed, other) for the operand found in ``keys``, checking
        # the left operand first; the keyed side wins when the other side
        # is not in ``keys`` or both compare equal.  None if neither.
        if left in keys and (right.compare(left) or right not in keys):
            return left, right
        if right in keys and (left.compare(right) or left not in keys):
            return right, left
        return None

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return
        left, right = binary.left, binary.right
        if not (
            isinstance(left, ColumnElement) and isinstance(right, ColumnElement)
        ):
            return

        if consider_as_foreign_keys:
            found = _keyed_side(consider_as_foreign_keys, left, right)
            if found is not None:
                fk_col, other = found
                pairs.append((other, fk_col))
        elif consider_as_referenced_keys:
            found = _keyed_side(consider_as_referenced_keys, left, right)
            if found is not None:
                pairs.append(found)
        elif isinstance(left, Column) and isinstance(right, Column):
            if left.references(right):
                pairs.append((right, left))
            elif right.references(left):
                pairs.append((left, right))

    visitors.traverse(expression, {}, {"binary": visit_binary})
    return pairs
1032
1033
1034class ClauseAdapter(visitors.ReplacingExternalTraversal):
1035 """Clones and modifies clauses based on column correspondence.
1036
1037 E.g.::
1038
1039 table1 = Table(
1040 "sometable",
1041 metadata,
1042 Column("col1", Integer),
1043 Column("col2", Integer),
1044 )
1045 table2 = Table(
1046 "someothertable",
1047 metadata,
1048 Column("col1", Integer),
1049 Column("col2", Integer),
1050 )
1051
1052 condition = table1.c.col1 == table2.c.col1
1053
1054 make an alias of table1::
1055
1056 s = table1.alias("foo")
1057
1058 calling ``ClauseAdapter(s).traverse(condition)`` converts
1059 condition to read::
1060
1061 s.c.col1 == table2.c.col1
1062
1063 """
1064
1065 __slots__ = (
1066 "__traverse_options__",
1067 "selectable",
1068 "include_fn",
1069 "exclude_fn",
1070 "equivalents",
1071 "adapt_on_names",
1072 "adapt_from_selectables",
1073 )
1074
1075 def __init__(
1076 self,
1077 selectable: Selectable,
1078 equivalents: Optional[_EquivalentColumnMap] = None,
1079 include_fn: Optional[Callable[[ClauseElement], bool]] = None,
1080 exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
1081 adapt_on_names: bool = False,
1082 anonymize_labels: bool = False,
1083 adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
1084 ):
1085 self.__traverse_options__ = {
1086 "stop_on": [selectable],
1087 "anonymize_labels": anonymize_labels,
1088 }
1089 self.selectable = selectable
1090 self.include_fn = include_fn
1091 self.exclude_fn = exclude_fn
1092 self.equivalents = util.column_dict(equivalents or {})
1093 self.adapt_on_names = adapt_on_names
1094 self.adapt_from_selectables = adapt_from_selectables
1095
    if TYPE_CHECKING:
        # Typing-only declarations: under a type checker these overloads
        # narrow the signature of the inherited traverse().  At runtime
        # this block does not execute and the method inherited from
        # ReplacingExternalTraversal is used unchanged.

        @overload
        def traverse(self, obj: Literal[None]) -> None: ...

        # note this specializes the ReplacingExternalTraversal.traverse()
        # method to state
        # that we will return the same kind of ExternalTraversal object as
        # we were given.  This is probably not 100% true, such as it's
        # possible for us to swap out Alias for Table at the top level.
        # Ideally there could be overloads specific to ColumnElement and
        # FromClause but Mypy is not accepting those as compatible with
        # the base ReplacingExternalTraversal
        @overload
        def traverse(self, obj: _ET) -> _ET: ...

        def traverse(
            self, obj: Optional[ExternallyTraversible]
        ) -> Optional[ExternallyTraversible]: ...
1115
1116 def _corresponding_column(
1117 self, col, require_embedded, _seen=util.EMPTY_SET
1118 ):
1119 newcol = self.selectable.corresponding_column(
1120 col, require_embedded=require_embedded
1121 )
1122 if newcol is None and col in self.equivalents and col not in _seen:
1123 for equiv in self.equivalents[col]:
1124 newcol = self._corresponding_column(
1125 equiv,
1126 require_embedded=require_embedded,
1127 _seen=_seen.union([col]),
1128 )
1129 if newcol is not None:
1130 return newcol
1131
1132 if (
1133 self.adapt_on_names
1134 and newcol is None
1135 and isinstance(col, NamedColumn)
1136 ):
1137 newcol = self.selectable.exported_columns.get(col.name)
1138 return newcol
1139
1140 @util.preload_module("sqlalchemy.sql.functions")
1141 def replace(
1142 self, col: _ET, _include_singleton_constants: bool = False
1143 ) -> Optional[_ET]:
1144 functions = util.preloaded.sql_functions
1145
1146 # TODO: cython candidate
1147
1148 if self.include_fn and not self.include_fn(col): # type: ignore
1149 return None
1150 elif self.exclude_fn and self.exclude_fn(col): # type: ignore
1151 return None
1152
1153 if isinstance(col, FromClause) and not isinstance(
1154 col, functions.FunctionElement
1155 ):
1156 if self.selectable.is_derived_from(col):
1157 if self.adapt_from_selectables:
1158 for adp in self.adapt_from_selectables:
1159 if adp.is_derived_from(col):
1160 break
1161 else:
1162 return None
1163 return self.selectable # type: ignore
1164 elif isinstance(col, Alias) and isinstance(
1165 col.element, TableClause
1166 ):
1167 # we are a SELECT statement and not derived from an alias of a
1168 # table (which nonetheless may be a table our SELECT derives
1169 # from), so return the alias to prevent further traversal
1170 # or
1171 # we are an alias of a table and we are not derived from an
1172 # alias of a table (which nonetheless may be the same table
1173 # as ours) so, same thing
1174 return col
1175 else:
1176 # other cases where we are a selectable and the element
1177 # is another join or selectable that contains a table which our
1178 # selectable derives from, that we want to process
1179 return None
1180
1181 elif not isinstance(col, ColumnElement):
1182 return None
1183 elif not _include_singleton_constants and col._is_singleton_constant:
1184 # dont swap out NULL, TRUE, FALSE for a label name
1185 # in a SQL statement that's being rewritten,
1186 # leave them as the constant. This is first noted in #6259,
1187 # however the logic to check this moved here as of #7154 so that
1188 # it is made specific to SQL rewriting and not all column
1189 # correspondence
1190
1191 return None
1192
1193 if "adapt_column" in col._annotations:
1194 col = col._annotations["adapt_column"]
1195
1196 if TYPE_CHECKING:
1197 assert isinstance(col, KeyedColumnElement)
1198
1199 if self.adapt_from_selectables and col not in self.equivalents:
1200 for adp in self.adapt_from_selectables:
1201 if adp.c.corresponding_column(col, False) is not None:
1202 break
1203 else:
1204 return None
1205
1206 if TYPE_CHECKING:
1207 assert isinstance(col, KeyedColumnElement)
1208
1209 return self._corresponding_column( # type: ignore
1210 col, require_embedded=True
1211 )
1212
1213
class _ColumnLookup(Protocol):
    """Typing protocol for the ``ColumnAdapter.columns`` collection.

    Describes an object indexed by an element that returns an element of
    the same general kind: ``None`` for ``None``, a column clause for a
    column clause, a column element for a column element, and otherwise
    the same type as the key.  The overloads exist only for static typing.
    """

    @overload
    def __getitem__(self, key: None) -> None: ...

    @overload
    def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: ...

    @overload
    def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: ...

    @overload
    def __getitem__(self, key: _ET) -> _ET: ...

    # implementation signature underlying the overloads above
    def __getitem__(self, key: Any) -> Any: ...
1228
1229
class ColumnAdapter(ClauseAdapter):
    """Extends ClauseAdapter with extra utility functions.

    Key aspects of ColumnAdapter include:

    * Expressions that are adapted are stored in a persistent
      .columns collection; so that an expression E adapted into
      an expression E1, will return the same object E1 when adapted
      a second time.   This is important in particular for things like
      Label objects that are anonymized, so that the ColumnAdapter can
      be used to present a consistent "adapted" view of things.

    * Exclusion of items from the persistent collection based on
      include/exclude rules, but also independent of hash identity.
      This is because "annotated" items all have the same hash identity as
      their parent.

    * "wrapping" capability is added, so that the replacement of an
      expression E can proceed through a series of adapters.  This differs
      from the visitor's "chaining" feature in that the resulting object is
      passed through all replacing functions unconditionally, rather than
      stopping at the first one that returns non-None.

    * An adapt_required option, used by eager loading to indicate that
      we don't trust a result row column that is not translated.
      This is to prevent a column from being interpreted as that
      of the child row in a self-referential scenario, see
      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency

    """

    __slots__ = (
        "columns",
        "adapt_required",
        "allow_label_resolve",
        "_wrap",
        "__weakref__",
    )

    # lookup of element -> adapted element; a util.WeakPopulateDict built
    # around _locate_col() (presumably populating entries on first access),
    # fronted by an _IncludeExcludeMapping when include_fn / exclude_fn
    # are configured
    columns: _ColumnLookup

    def __init__(
        self,
        selectable: Selectable,
        equivalents: Optional[_EquivalentColumnMap] = None,
        adapt_required: bool = False,
        include_fn: Optional[Callable[[ClauseElement], bool]] = None,
        exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
        adapt_on_names: bool = False,
        allow_label_resolve: bool = True,
        anonymize_labels: bool = False,
        adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
    ):
        """Construct a ColumnAdapter.

        :param adapt_required: when true, ``_locate_col()`` returns ``None``
         for an element that no adapter changed.
        :param allow_label_resolve: value assigned to
         ``_allow_label_resolve`` on every newly adapted element.

        The remaining parameters are passed to :class:`.ClauseAdapter`.
        """
        super().__init__(
            selectable,
            equivalents,
            include_fn=include_fn,
            exclude_fn=exclude_fn,
            adapt_on_names=adapt_on_names,
            anonymize_labels=anonymize_labels,
            adapt_from_selectables=adapt_from_selectables,
        )

        self.columns = util.WeakPopulateDict(self._locate_col)  # type: ignore
        if self.include_fn or self.exclude_fn:
            self.columns = self._IncludeExcludeMapping(self, self.columns)
        self.adapt_required = adapt_required
        self.allow_label_resolve = allow_label_resolve
        self._wrap = None

    class _IncludeExcludeMapping:
        """Column lookup that honors the parent's include_fn / exclude_fn.

        Keys rejected by those filters bypass the wrapped ``columns``
        lookup.  They are resolved through the parent's ``_wrap`` adapter
        when one is set, and otherwise returned unchanged.
        """

        def __init__(self, parent, columns):
            self.parent = parent
            self.columns = columns

        def __getitem__(self, key):
            if (
                self.parent.include_fn and not self.parent.include_fn(key)
            ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)):
                if self.parent._wrap:
                    return self.parent._wrap.columns[key]
                else:
                    return key
            return self.columns[key]

    def wrap(self, adapter):
        """Return a shallow copy of this adapter whose results are also
        passed through ``adapter``.

        The copy gets its own fresh ``columns`` lookup, so results already
        computed by this adapter are not reused.
        """
        ac = copy.copy(self)
        ac._wrap = adapter
        ac.columns = util.WeakPopulateDict(ac._locate_col)  # type: ignore
        if ac.include_fn or ac.exclude_fn:
            ac.columns = self._IncludeExcludeMapping(ac, ac.columns)

        return ac

    @overload
    def traverse(self, obj: Literal[None]) -> None: ...

    @overload
    def traverse(self, obj: _ET) -> _ET: ...

    def traverse(
        self, obj: Optional[ExternallyTraversible]
    ) -> Optional[ExternallyTraversible]:
        """Return the adapted form of ``obj`` from the ``columns`` lookup.

        ``None`` is returned as-is without a lookup, as the
        ``Literal[None]`` overload promises.  ``_locate_col()`` cannot
        accept ``None`` because it reads ``col._is_immutable``.
        """
        if obj is None:
            return None
        return self.columns[obj]

    def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
        """Chain ``visitor`` after this adapter; only ColumnAdapter
        instances are accepted."""
        assert isinstance(visitor, ColumnAdapter)

        return super().chain(visitor)

    if TYPE_CHECKING:

        @property
        def visitor_iterator(self) -> Iterator[ColumnAdapter]: ...

    adapt_clause = traverse
    adapt_list = ClauseAdapter.copy_and_process

    def adapt_check_present(
        self, col: ColumnElement[Any]
    ) -> Optional[ColumnElement[Any]]:
        """Return the adapted form of ``col``, or ``None`` if adaptation
        left ``col`` unchanged and ``_corresponding_column()`` finds no
        match for it in the selectable."""
        newcol = self.columns[col]

        if newcol is col and self._corresponding_column(col, True) is None:
            return None

        return newcol

    def _locate_col(
        self, col: ColumnElement[Any]
    ) -> Optional[ColumnElement[Any]]:
        """Compute the adapted form of ``col`` for the ``columns`` lookup.

        Returns ``None`` when ``adapt_required`` is set and neither this
        adapter nor a wrapped adapter changed ``col``.
        """
        # both replace and traverse() are overly complicated for what
        # we are doing here and we would do better to have an inlined
        # version that doesn't build up as much overhead.  the issue is that
        # sometimes the lookup does in fact have to adapt the insides of
        # say a labeled scalar subquery.   However, if the object is an
        # Immutable, i.e. Column objects, we can skip the "clone" /
        # "copy internals" part since those will be no-ops in any case.
        # additionally we want to catch singleton objects null/true/false
        # and make sure they are adapted as well here.

        if col._is_immutable:
            for vis in self.visitor_iterator:
                c = vis.replace(col, _include_singleton_constants=True)
                if c is not None:
                    break
            else:
                c = col
        else:
            c = ClauseAdapter.traverse(self, col)

        # a wrapped adapter gets a second pass over our result
        if self._wrap:
            c2 = self._wrap._locate_col(c)
            if c2 is not None:
                c = c2

        if self.adapt_required and c is col:
            return None

        # allow_label_resolve is consumed by one case for joined eager loading
        # as part of its logic to prevent its own columns from being affected
        # by .order_by().  Before full typing were applied to the ORM, this
        # logic would set this attribute on the incoming object (which is
        # typically a column, but we have a test for it being a non-column
        # object) if no column were found.  While this seemed to
        # have no negative effects, this adjustment should only occur on the
        # new column which is assumed to be local to an adapted selectable.
        if c is not col:
            c._allow_label_resolve = self.allow_label_resolve

        return c
1401
1402
def _offset_or_limit_clause(
    element: _LimitOffsetType,
    name: Optional[str] = None,
    type_: Optional[_TypeEngineArgument[int]] = None,
) -> ColumnElement[int]:
    """Coerce ``element`` into an expression usable as LIMIT or OFFSET.

    Incoming integers are converted into an expression, while an
    expression that is already given passes through.  The work is
    delegated to ``coercions.expect()`` under ``roles.LimitOffsetRole``.

    :param element: the value to coerce.
    :param name: optional name, forwarded to ``coercions.expect()``.
    :param type_: optional type, forwarded to ``coercions.expect()``.
    :return: the coerced expression.

    """
    role = roles.LimitOffsetRole
    coerced = coercions.expect(role, element, type_=type_, name=name)
    return coerced
1417
1418
1419def _offset_or_limit_clause_asint_if_possible(
1420 clause: _LimitOffsetType,
1421) -> _LimitOffsetType:
1422 """Return the offset or limit clause as a simple integer if possible,
1423 else return the clause.
1424
1425 """
1426 if clause is None:
1427 return None
1428 if hasattr(clause, "_limit_offset_value"):
1429 value = clause._limit_offset_value
1430 return util.asint(value)
1431 else:
1432 return clause
1433
1434
1435def _make_slice(
1436 limit_clause: _LimitOffsetType,
1437 offset_clause: _LimitOffsetType,
1438 start: int,
1439 stop: int,
1440) -> Tuple[Optional[ColumnElement[int]], Optional[ColumnElement[int]]]:
1441 """Compute LIMIT/OFFSET in terms of slice start/end"""
1442
1443 # for calculated limit/offset, try to do the addition of
1444 # values to offset in Python, however if a SQL clause is present
1445 # then the addition has to be on the SQL side.
1446
1447 # TODO: typing is finding a few gaps in here, see if they can be
1448 # closed up
1449
1450 if start is not None and stop is not None:
1451 offset_clause = _offset_or_limit_clause_asint_if_possible(
1452 offset_clause
1453 )
1454 if offset_clause is None:
1455 offset_clause = 0
1456
1457 if start != 0:
1458 offset_clause = offset_clause + start # type: ignore
1459
1460 if offset_clause == 0:
1461 offset_clause = None
1462 else:
1463 assert offset_clause is not None
1464 offset_clause = _offset_or_limit_clause(offset_clause)
1465
1466 limit_clause = _offset_or_limit_clause(stop - start)
1467
1468 elif start is None and stop is not None:
1469 limit_clause = _offset_or_limit_clause(stop)
1470 elif start is not None and stop is None:
1471 offset_clause = _offset_or_limit_clause_asint_if_possible(
1472 offset_clause
1473 )
1474 if offset_clause is None:
1475 offset_clause = 0
1476
1477 if start != 0:
1478 offset_clause = offset_clause + start
1479
1480 if offset_clause == 0:
1481 offset_clause = None
1482 else:
1483 offset_clause = _offset_or_limit_clause(offset_clause)
1484
1485 return limit_clause, offset_clause