1# sql/util.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""High level utilities which build upon other modules here."""
10from __future__ import annotations
11
12from collections import deque
13import copy
14from itertools import chain
15import typing
16from typing import AbstractSet
17from typing import Any
18from typing import Callable
19from typing import cast
20from typing import Collection
21from typing import Dict
22from typing import Iterable
23from typing import Iterator
24from typing import List
25from typing import Optional
26from typing import overload
27from typing import Protocol
28from typing import Sequence
29from typing import Tuple
30from typing import TYPE_CHECKING
31from typing import TypeVar
32from typing import Union
33
34from . import coercions
35from . import operators
36from . import roles
37from . import visitors
38from ._typing import is_text_clause
39from .annotation import _deep_annotate as _deep_annotate # noqa: F401
40from .annotation import _deep_deannotate as _deep_deannotate # noqa: F401
41from .annotation import _shallow_annotate as _shallow_annotate # noqa: F401
42from .base import _expand_cloned
43from .base import _from_objects
44from .cache_key import HasCacheKey as HasCacheKey # noqa: F401
45from .ddl import sort_tables as sort_tables # noqa: F401
46from .elements import _find_columns as _find_columns
47from .elements import _label_reference
48from .elements import _textual_label_reference
49from .elements import BindParameter
50from .elements import ClauseElement
51from .elements import ColumnClause
52from .elements import ColumnElement
53from .elements import Grouping
54from .elements import KeyedColumnElement
55from .elements import Label
56from .elements import NamedColumn
57from .elements import Null
58from .elements import UnaryExpression
59from .schema import Column
60from .selectable import Alias
61from .selectable import FromClause
62from .selectable import FromGrouping
63from .selectable import Join
64from .selectable import ScalarSelect
65from .selectable import SelectBase
66from .selectable import TableClause
67from .visitors import _ET
68from .. import exc
69from .. import util
70from ..util.typing import Literal
71from ..util.typing import Unpack
72
73if typing.TYPE_CHECKING:
74 from ._typing import _EquivalentColumnMap
75 from ._typing import _LimitOffsetType
76 from ._typing import _TypeEngineArgument
77 from .elements import BinaryExpression
78 from .elements import TextClause
79 from .selectable import _JoinTargetElement
80 from .selectable import _SelectIterable
81 from .selectable import Selectable
82 from .visitors import _TraverseCallableType
83 from .visitors import ExternallyTraversible
84 from .visitors import ExternalTraversal
85 from ..engine.interfaces import _AnyExecuteParams
86 from ..engine.interfaces import _AnyMultiExecuteParams
87 from ..engine.interfaces import _AnySingleExecuteParams
88 from ..engine.interfaces import _CoreSingleExecuteParams
89 from ..engine.row import Row
90
# TypeVar bound to ColumnElement; lets adapt_criterion_to_null() declare
# that it returns the same expression type it was given.
_CE = TypeVar("_CE", bound="ColumnElement[Any]")
92
93
def join_condition(
    a: FromClause,
    b: FromClause,
    a_subset: Optional[FromClause] = None,
    consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]] = None,
) -> ColumnElement[bool]:
    """Create a join condition between two tables or selectables.

    e.g.::

        join_condition(tablea, tableb)

    would produce an expression along the lines of::

        tablea.c.id == tableb.c.tablea_id

    The join is determined based on the foreign key relationships
    between the two selectables.  If there are multiple ways
    to join, or no way to join, an error is raised.

    This is a thin public front-end for ``Join._join_condition()``; all
    arguments are forwarded unchanged.

    :param a_subset: An optional expression that is a sub-component
     of ``a``.  An attempt will be made to join to just this sub-component
     first before looking at the full ``a`` construct, and if found
     will be successful even if there are other ways to join to ``a``.
     This allows the "right side" of a join to be passed thereby
     providing a "natural join".

    :param consider_as_foreign_keys: optional set of columns, forwarded
     to ``Join._join_condition()``.

    """
    condition = Join._join_condition(
        a,
        b,
        a_subset=a_subset,
        consider_as_foreign_keys=consider_as_foreign_keys,
    )
    return condition
128
129
def find_join_source(
    clauses: List[FromClause], join_to: FromClause
) -> List[int]:
    """Given a list of FROM clauses and a selectable, return the indexes
    of the clauses which can be joined against the selectable.

    A clause qualifies when it is derived from any of the FROM objects of
    ``join_to``.  Each qualifying index appears exactly once, in ascending
    order; an empty list is returned when nothing matches.

    e.g.::

        clause1 = table1.join(table2)
        clause2 = table4.join(table5)

        join_to = table2.join(table3)

        find_join_source([clause1, clause2], join_to) == [0]

    """

    selectables = list(_from_objects(join_to))
    # any() stops at the first matching selectable, so an index is never
    # reported twice even when a clause derives from several of them.
    return [
        i
        for i, f in enumerate(clauses)
        if any(f.is_derived_from(s) for s in selectables)
    ]
156
157
def find_left_clause_that_matches_given(
    clauses: Sequence[FromClause], join_from: FromClause
) -> List[int]:
    """Given a list of FROM clauses and a selectable,
    return the indexes from the list of clauses which are derived from
    the selectable.

    When more than one clause matches, a stricter second pass keeps only
    those clauses whose surface selectables directly overlap those of
    ``join_from``; if that pass finds nothing, the looser result stands.

    """

    selectables = list(_from_objects(join_from))

    # Loose pass: f is derived from s.  This can be joins containing a
    # table, an aliased table, or a select statement matching a table; it
    # also matches a table to a selectable adapted from that table.  With
    # Query, this suits the case where a join is made to an adapted entity.
    liberal_idx = [
        i
        for i, f in enumerate(clauses)
        if any(f.is_derived_from(s) for s in selectables)
    ]

    if len(liberal_idx) <= 1:
        return liberal_idx

    # In an extremely small set of use cases the target table is
    # represented in more than one FROM clause (embedded or similar).
    # Look for an exact surface overlap, ignoring adaption relationships.
    conservative_idx = [
        i
        for i in liberal_idx
        if any(
            set(surface_selectables(clauses[i])).intersection(
                surface_selectables(s)
            )
            for s in selectables
        )
    ]
    return conservative_idx or liberal_idx
200
201
def find_left_clause_to_join_from(
    clauses: Sequence[FromClause],
    join_to: _JoinTargetElement,
    onclause: Optional[ColumnElement[Any]],
) -> List[int]:
    """Given a list of FROM clauses, a selectable,
    and optional ON clause, return a list of integer indexes from the
    clauses list indicating the clauses that can be joined from.

    The presence of an "onclause" indicates that at least one clause can
    definitely be joined from; if the list of clauses is of length one
    and the onclause is given, returns that index.  If the list of clauses
    is more than length one, and the onclause is given, attempts to locate
    which clauses contain the same columns.  If an onclause is given and
    no clause qualifies, every index is returned.

    """
    selectables = set(_from_objects(join_to))

    # Only narrow by ON-clause columns when there is a real choice to make;
    # otherwise don't try to limit -- after all, "ON TRUE" is a valid
    # on clause.
    if len(clauses) > 1 and onclause is not None:
        cols_in_onclause: Optional[AbstractSet[ColumnClause[Any]]] = (
            _find_columns(onclause)
        )
    else:
        cols_in_onclause = None

    def can_join_from(f: FromClause) -> bool:
        for s in selectables.difference([f]):
            if cols_in_onclause is not None:
                if set(f.c).union(s.c).issuperset(cols_in_onclause):
                    return True
            elif onclause is not None or Join._can_join(f, s):
                return True
        return False

    idx = [i for i, f in enumerate(clauses) if can_join_from(f)]

    if len(idx) > 1:
        # the same "hide froms" logic as Selectable._get_display_froms
        hidden = set(
            chain.from_iterable(
                _expand_cloned(f._hide_froms) for f in clauses
            )
        )
        idx = [i for i in idx if clauses[i] not in hidden]

    if idx or onclause is None:
        return idx

    # an onclause was given and none of the clauses resolved, so assume
    # all indexes can match
    return list(range(len(clauses)))
257
258
def visit_binary_product(
    fn: Callable[
        [BinaryExpression[Any], ColumnElement[Any], ColumnElement[Any]], None
    ],
    expr: ColumnElement[Any],
) -> None:
    """Produce a traversal of the given expression, delivering
    column comparisons to the given function.

    The function is of the form::

        def my_fn(binary, left, right): ...

    For each binary expression located which has a
    comparison operator, the product of "left" and
    "right" will be delivered to that function,
    in terms of that binary.

    Hence an expression like::

        and_((a + b) == q + func.sum(e + f), j == r)

    would have the traversal:

    .. sourcecode:: text

        a <eq> q
        a <eq> e
        a <eq> f
        b <eq> q
        b <eq> e
        b <eq> f
        j <eq> r

    That is, every combination of "left" and
    "right" that doesn't further contain
    a binary comparison is passed as pairs.

    """

    def visit(element: ClauseElement) -> Iterator[ColumnElement[Any]]:
        if isinstance(element, ScalarSelect):
            # we don't want to dig into correlated subqueries,
            # those are just column elements by themselves
            yield element
        elif element.__visit_name__ == "binary" and operators.is_comparison(
            element.operator  # type: ignore
        ):
            binary = cast("BinaryExpression[Any]", element)
            # A comparison yields nothing to its parent; instead it reports
            # each (left, right) pairing to ``fn``.  Comparisons nested on
            # either side are reached -- and reported against their own
            # binary -- while these generators are consumed, so no separate
            # walk over get_children() is needed here.
            for l in visit(binary.left):
                for r in visit(binary.right):
                    fn(binary, l, r)
        else:
            if isinstance(element, ColumnClause):
                yield element
            for elem in element.get_children():
                yield from visit(elem)

    list(visit(expr))
    visit = None  # type: ignore # remove gc cycles
323
324
def find_tables(
    clause: ClauseElement,
    *,
    check_columns: bool = False,
    include_aliases: bool = False,
    include_joins: bool = False,
    include_selects: bool = False,
    include_crud: bool = False,
) -> List[TableClause]:
    """locate Table objects within the given expression.

    Tables are always collected.  The keyword flags additionally collect
    SELECT / compound SELECT objects, joins, alias-like FROMs (alias,
    subquery, tablesample, lateral), the target table of INSERT / UPDATE /
    DELETE, and the ``.table`` of every column, respectively.  Results
    are returned in traversal order and may contain repeats.
    """

    tables: List[TableClause] = []
    add = tables.append
    _visitors: Dict[str, _TraverseCallableType[Any]] = {}

    if include_selects:
        _visitors.update(select=add, compound_select=add)

    if include_joins:
        _visitors["join"] = add

    if include_aliases:
        for visit_name in ("alias", "subquery", "tablesample", "lateral"):
            _visitors[visit_name] = add

    if include_crud:

        def visit_crud(ent):
            add(ent.table)

        for visit_name in ("insert", "update", "delete"):
            _visitors[visit_name] = visit_crud

    if check_columns:

        def visit_column(column):
            add(column.table)

        _visitors["column"] = visit_column

    _visitors["table"] = add

    visitors.traverse(clause, {}, _visitors)
    return tables
366
367
def unwrap_order_by(clause: Any) -> Any:
    """Break up an 'order by' expression into individual column-expressions,
    without DESC/ASC/NULLS FIRST/NULLS LAST.

    Walks ``clause`` breadth-first and returns a de-duplicated list of the
    column expressions found, in discovery order.  Ordering modifiers
    (UnaryExpressions whose modifier is an ordering modifier) and anything
    that is not a ColumnElement are descended into rather than collected.
    Labels (unless wrapping a ScalarSelect) and ``_label_reference``
    wrappers are unwrapped and re-queued; textual label references are
    dropped.
    """

    cols = util.column_set()
    result = []
    stack = deque([clause])

    # examples
    # column -> ASC/DESC == column
    # column -> ASC/DESC -> label == column
    # column -> label -> ASC/DESC -> label == column
    # scalar_select -> label -> ASC/DESC == scalar_select -> label

    while stack:
        t = stack.popleft()
        if isinstance(t, ColumnElement) and (
            not isinstance(t, UnaryExpression)
            or not operators.is_ordering_modifier(t.modifier)  # type: ignore
        ):
            if isinstance(t, Label) and not isinstance(
                t.element, ScalarSelect
            ):
                # strip the label (and a directly-wrapped Grouping) and
                # re-examine what's underneath on a later iteration
                t = t.element

                if isinstance(t, Grouping):
                    t = t.element

                stack.append(t)
                continue
            elif isinstance(t, _label_reference):
                t = t.element

                stack.append(t)
                continue
            if isinstance(t, (_textual_label_reference)):
                # textual label references have no column to report
                continue
            # column_set membership de-duplicates; result keeps order
            if t not in cols:
                cols.add(t)
                result.append(t)

        else:
            # ordering modifiers and non-column constructs: look inside
            for c in t.get_children():
                stack.append(c)
    return result
413
414
def unwrap_label_reference(element):
    """Return a copy of ``element`` in which every ``_label_reference`` is
    replaced by the element it wraps.

    Encountering a ``_textual_label_reference`` trips an assertion, since
    there is no underlying element to substitute.
    """

    def _unwrap(
        node: ExternallyTraversible, **kw: Any
    ) -> Optional[ExternallyTraversible]:
        if isinstance(node, _label_reference):
            return node.element
        if isinstance(node, _textual_label_reference):
            assert False, "can't unwrap a textual label reference"
        # None tells replacement_traverse to keep traversing as usual
        return None

    return visitors.replacement_traverse(element, {}, _unwrap)
426
427
def expand_column_list_from_order_by(collist, order_by):
    """Given the columns clause and ORDER BY of a selectable,
    return a list of column expressions that can be added to the collist
    corresponding to the ORDER BY, without repeating those already
    in the collist.

    Columns in ``collist`` that carry an order-by label element are
    compared by their underlying ``.element``.
    """
    already_present = set()
    for col in collist:
        if col._order_by_label_element is not None:
            already_present.add(col.element)
        else:
            already_present.add(col)

    candidates = chain.from_iterable(unwrap_order_by(o) for o in order_by)

    return [col for col in candidates if col not in already_present]
443
444
def clause_is_present(clause, search):
    """Given a target clause and a second to search within, return True
    if the target is plainly present in the search without any
    subqueries or aliases involved.

    Basically descends through Joins (and FROM groupings) only.

    """
    # use == rather than ``is`` so that Annotated's comparison applies
    return any(clause == elem for elem in surface_selectables(search))
459
460
def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]:
    """Yield the leaf FROM objects of ``clause`` in left-to-right order,
    recursing through Joins and unwrapping FromGroupings."""
    if isinstance(clause, Join):
        for side in (clause.left, clause.right):
            yield from tables_from_leftmost(side)
        return
    if isinstance(clause, FromGrouping):
        yield from tables_from_leftmost(clause.element)
        return
    yield clause
469
470
def surface_selectables(clause):
    """Yield ``clause`` and everything reachable from it through Join
    sides and FromGrouping wrappers (the Joins/groupings themselves are
    yielded too).  Traversal is depth-first using an explicit stack."""
    pending = [clause]
    while pending:
        current = pending.pop()
        yield current
        if isinstance(current, Join):
            pending += (current.left, current.right)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
480
481
def surface_selectables_only(clause: ClauseElement) -> Iterator[ClauseElement]:
    """Yield the leaf FROM-level objects at the surface of ``clause``.

    Joins are split into their left and right sides and FromGroupings are
    unwrapped; neither is yielded itself.  A ColumnClause is replaced by
    its ``.table`` when it has one, otherwise the column itself is yielded.
    Every other non-None element (TableClause and Alias included) is
    yielded exactly once.

    """
    stack = [clause]
    while stack:
        elem = stack.pop()
        if isinstance(elem, Join):
            stack.extend((elem.left, elem.right))
        elif isinstance(elem, FromGrouping):
            stack.append(elem.element)
        elif isinstance(elem, ColumnClause):
            if elem.table is not None:
                stack.append(elem.table)
            else:
                yield elem
        elif elem is not None:
            # TableClause / Alias end up here; a separate up-front check
            # for them would yield each one a second time.
            yield elem
499
500
def extract_first_column_annotation(column, annotation_name):
    """Breadth-first search ``column`` and its children for the first
    element whose ``_annotations`` contains ``annotation_name`` and return
    that value; return None if no element carries it.

    FromGrouping and SelectBase children are not descended into.
    """
    skip_types = (FromGrouping, SelectBase)

    queue = deque([column])
    while queue:
        elem = queue.popleft()
        if annotation_name in elem._annotations:
            return elem._annotations[annotation_name]
        queue.extend(
            child
            for child in elem.get_children()
            if not isinstance(child, skip_types)
        )
    return None
514
515
def selectables_overlap(left: FromClause, right: FromClause) -> bool:
    """Return True if left/right have some overlapping selectable,
    comparing everything surface_selectables() yields for each side."""

    left_surface = set(surface_selectables(left))
    return not left_surface.isdisjoint(surface_selectables(right))
522
523
def bind_values(clause):
    """Return an ordered list of "bound" values in the given clause.

    Each bind parameter contributes its ``effective_value``, in traversal
    order.

    E.g.::

        >>> expr = and_(table.c.foo == 5, table.c.foo == 7)
        >>> bind_values(expr)
        [5, 7]
    """

    collected = []
    visitors.traverse(
        clause,
        {},
        {"bindparam": lambda bind: collected.append(bind.effective_value)},
    )
    return collected
541
542
543def _quote_ddl_expr(element):
544 if isinstance(element, str):
545 element = element.replace("'", "''")
546 return "'%s'" % element
547 else:
548 return repr(element)
549
550
551class _repr_base:
552 _LIST: int = 0
553 _TUPLE: int = 1
554 _DICT: int = 2
555
556 __slots__ = ("max_chars",)
557
558 max_chars: int
559
560 def trunc(self, value: Any) -> str:
561 rep = repr(value)
562 lenrep = len(rep)
563 if lenrep > self.max_chars:
564 segment_length = self.max_chars // 2
565 rep = (
566 rep[0:segment_length]
567 + (
568 " ... (%d characters truncated) ... "
569 % (lenrep - self.max_chars)
570 )
571 + rep[-segment_length:]
572 )
573 return rep
574
575
def _repr_single_value(value):
    """Return ``repr(value)`` truncated to 300 characters."""
    formatter = _repr_base()
    formatter.max_chars = 300
    rendered = formatter.trunc(value)
    return rendered
580
581
class _repr_row(_repr_base):
    """Provide a string view of a row.

    Renders like a tuple, with each value truncated via ``trunc()``.
    """

    __slots__ = ("row",)

    def __init__(
        self, row: Row[Unpack[Tuple[Any, ...]]], max_chars: int = 300
    ):
        self.row = row
        self.max_chars = max_chars

    def __repr__(self) -> str:
        rendered = [self.trunc(value) for value in self.row]
        # a one-element tuple needs its trailing comma
        trailing = "," if len(self.row) == 1 else ""
        return "(" + ", ".join(rendered) + trailing + ")"
599
600
601class _long_statement(str):
602 def __str__(self) -> str:
603 lself = len(self)
604 if lself > 500:
605 lleft = 250
606 lright = 100
607 trunc = lself - lleft - lright
608 return (
609 f"{self[0:lleft]} ... {trunc} "
610 f"characters truncated ... {self[-lright:]}"
611 )
612 else:
613 return str.__str__(self)
614
615
class _repr_params(_repr_base):
    """Provide a string view of bound parameters.

    Truncates display to a given number of 'multi' parameter sets,
    as well as long values to a given number of characters.

    :param params: the parameters to display; a single list / tuple / dict
     parameter set, or (when ``ismulti`` is true) a sequence of such sets.
    :param batches: maximum number of parameter sets displayed when
     ``ismulti`` is true.
    :param max_params: maximum number of individual values displayed per
     parameter set; longer sets show the first and last halves.
    :param max_chars: maximum ``repr()`` length of a single value.
    :param ismulti: ``None`` renders ``params`` as one truncated
     ``repr()``; otherwise selects multi-set vs. single-set rendering.

    """

    __slots__ = "params", "batches", "ismulti", "max_params"

    def __init__(
        self,
        params: Optional[_AnyExecuteParams],
        batches: int,
        max_params: int = 100,
        max_chars: int = 300,
        ismulti: Optional[bool] = None,
    ):
        self.params = params
        self.ismulti = ismulti
        self.batches = batches
        self.max_chars = max_chars
        self.max_params = max_params

    def __repr__(self) -> str:
        if self.ismulti is None:
            return self.trunc(self.params)

        if isinstance(self.params, list):
            typ = self._LIST

        elif isinstance(self.params, tuple):
            typ = self._TUPLE
        elif isinstance(self.params, dict):
            typ = self._DICT
        else:
            return self.trunc(self.params)

        if self.ismulti:
            multi_params = cast(
                "_AnyMultiExecuteParams",
                self.params,
            )

            if len(self.params) > self.batches:
                msg = (
                    " ... displaying %i of %i total bound parameter sets ... "
                )
                # clamp at zero: with batches < 2 a negative stop index
                # would display nearly every parameter set
                num_leading = max(self.batches - 2, 0)
                return " ".join(
                    (
                        self._repr_multi(
                            multi_params[:num_leading],
                            typ,
                        )[0:-1],
                        msg % (self.batches, len(self.params)),
                        self._repr_multi(multi_params[-2:], typ)[1:],
                    )
                )
            else:
                return self._repr_multi(multi_params, typ)
        else:
            return self._repr_params(
                cast(
                    "_AnySingleExecuteParams",
                    self.params,
                ),
                typ,
            )

    def _repr_multi(
        self,
        multi_params: _AnyMultiExecuteParams,
        typ: int,
    ) -> str:
        """Render a sequence of parameter sets, bracketed per ``typ``."""
        if multi_params:
            if isinstance(multi_params[0], list):
                elem_type = self._LIST
            elif isinstance(multi_params[0], tuple):
                elem_type = self._TUPLE
            elif isinstance(multi_params[0], dict):
                elem_type = self._DICT
            else:
                assert False, "Unknown parameter type %s" % (
                    type(multi_params[0])
                )

            elements = ", ".join(
                self._repr_params(params, elem_type) for params in multi_params
            )
        else:
            elements = ""

        if typ == self._LIST:
            return "[%s]" % elements
        else:
            return "(%s)" % elements

    def _get_batches(self, params: Iterable[Any]) -> Any:
        """Split ``params`` for display.

        Returns ``(all_items, None, None)`` when there are at most
        ``max_params`` items; otherwise ``(head, tail, num_omitted)`` where
        head and tail are each ``max_params // 2`` items.
        """
        lparams = list(params)
        lenparams = len(lparams)
        if lenparams > self.max_params:
            lleft = max(self.max_params // 2, 0)
            return (
                lparams[0:lleft],
                # explicit start index: lparams[-0:] is the whole list
                lparams[lenparams - lleft :],
                # count what is really hidden (correct for odd max_params)
                lenparams - 2 * lleft,
            )
        else:
            return lparams, None, None

    def _repr_params(
        self,
        params: _AnySingleExecuteParams,
        typ: int,
    ) -> str:
        """Render a single parameter set according to ``typ``."""
        if typ == self._DICT:
            return self._repr_param_dict(
                cast("_CoreSingleExecuteParams", params)
            )
        elif typ == self._TUPLE:
            return self._repr_param_tuple(cast("Sequence[Any]", params))
        else:
            return self._repr_param_list(params)

    def _repr_param_dict(self, params: _CoreSingleExecuteParams) -> str:
        trunc = self.trunc
        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params.items())

        # None (not merely empty) means nothing was truncated
        if items_second_batch is not None:
            text = "{%s" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_first_batch
                )
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s}" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_second_batch
                )
            )
        else:
            text = "{%s}" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_first_batch
                )
            )
        return text

    def _repr_param_tuple(self, params: Sequence[Any]) -> str:
        trunc = self.trunc

        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params)

        if items_second_batch is not None:
            text = "(%s" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s)" % (
                ", ".join(trunc(value) for value in items_second_batch),
            )
        else:
            text = "(%s%s)" % (
                ", ".join(trunc(value) for value in items_first_batch),
                "," if len(items_first_batch) == 1 else "",
            )
        return text

    def _repr_param_list(self, params: _AnySingleExecuteParams) -> str:
        trunc = self.trunc
        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params)

        if items_second_batch is not None:
            text = "[%s" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s]" % (
                ", ".join(trunc(value) for value in items_second_batch)
            )
        else:
            text = "[%s]" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
        return text
816
817
def adapt_criterion_to_null(crit: _CE, nulls: Collection[Any]) -> _CE:
    """given criterion containing bind params, convert selected elements
    to IS NULL.

    Works on a clone of ``crit``: every binary whose bind parameter's
    identifying key is in ``nulls`` becomes ``<other side> IS NULL``
    (negating to IS NOT).

    """

    def _to_is_null(binary):
        left, right = binary.left, binary.right
        if isinstance(left, BindParameter) and left._identifying_key in nulls:
            # NULL would be on the left; move the other side over so the
            # result reads "<expr> IS NULL"
            binary.left = right
        elif not (
            isinstance(right, BindParameter)
            and right._identifying_key in nulls
        ):
            return
        binary.right = Null()
        binary.operator = operators.is_
        binary.negate = operators.is_not

    return visitors.cloned_traverse(crit, {}, {"binary": _to_is_null})
843
844
def splice_joins(
    left: Optional[FromClause],
    right: Optional[FromClause],
    stop_on: Optional[FromClause] = None,
) -> Optional[FromClause]:
    """Splice ``left`` into the leftmost position of the join chain
    ``right``.

    Walks down the left spine of ``right``, cloning each Join (other than
    ``stop_on``) and adapting its ON clause against ``left`` with a
    ClauseAdapter.  The first non-join reached (or ``stop_on``) is adapted
    the same way and becomes the new left side of its parent clone.
    Returns the new top of the chain, or ``right`` unchanged when ``left``
    is None.
    """
    if left is None:
        return right

    adapter = ClauseAdapter(left)
    result: Optional[FromClause] = None
    parent: Optional[Join] = None
    current: Optional[FromClause] = right

    while isinstance(current, Join) and current is not stop_on:
        spliced = current._clone()
        spliced.onclause = adapter.traverse(spliced.onclause)
        if parent is not None:
            parent.left = spliced
        if result is None:
            result = spliced
        parent, current = spliced, spliced.left

    leaf = adapter.traverse(current)
    if parent is not None:
        assert leaf is not None
        parent.left = leaf
    return leaf if result is None else result
872
873
@overload
def reduce_columns(
    columns: Iterable[ColumnElement[Any]],
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Sequence[ColumnElement[Any]]: ...


@overload
def reduce_columns(
    columns: _SelectIterable,
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Sequence[Union[ColumnElement[Any], TextClause]]: ...


def reduce_columns(
    columns: _SelectIterable,
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Collection[Union[ColumnElement[Any], TextClause]]:
    r"""given a list of columns, return a 'reduced' set based on natural
    equivalents.

    the set is reduced to the smallest list of columns which have no natural
    equivalent present in the list.  A "natural equivalent" means that two
    columns will ultimately represent the same value because they are related
    by a foreign key.

    \*clauses is an optional list of join clauses which will be traversed
    to further identify columns that are "equivalent"; within them, each
    ``==`` comparison between two still-retained columns causes the
    (last-listed) column sharing lineage with its right side to be omitted.

    \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
    whose tables are not yet configured, or columns that aren't yet present.
    It may also specify 'only_synonyms', in which case a column is only
    omitted when its ``.name`` matches the column it is equivalent to.
    Any other keyword is ignored.

    Text clauses in ``columns`` are never omitted.  The result is an
    ordered set preserving the input order.

    This function is primarily used to determine the most minimal "primary
    key" from a selectable, by reducing the set of primary key columns present
    in the selectable to just those that are not repeated.

    """
    ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
    only_synonyms = kw.pop("only_synonyms", False)

    column_set = util.OrderedSet(columns)
    # text clauses have no lineage / foreign keys; exclude them from the
    # analysis (they remain in column_set and therefore in the result)
    cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference(
        c for c in column_set if is_text_clause(c)  # type: ignore
    )

    omit = util.column_set()
    for col in cset_no_text:
        # every foreign key on any column that ``col`` proxies
        for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
            for c in cset_no_text:
                if c is col:
                    continue
                # fk.column is resolved lazily here, inside the loop, so an
                # unresolvable FK only raises once there is another column
                # to compare against
                try:
                    fk_col = fk.column
                except exc.NoReferencedColumnError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                except exc.NoReferencedTableError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                # col references another listed column: col is redundant
                if fk_col.shares_lineage(c) and (
                    not only_synonyms or c.name == col.name
                ):
                    omit.add(col)
                    # leaves only the loop over ``c``; col's remaining
                    # foreign keys are still examined
                    break

    if clauses:

        def visit_binary(binary):
            if binary.operator == operators.eq:
                # recomputed per comparison because ``omit`` grows as the
                # traversal proceeds
                cols = util.column_set(
                    chain(
                        *[c.proxy_set for c in cset_no_text.difference(omit)]
                    )
                )
                if binary.left in cols and binary.right in cols:
                    # prefer omitting the later-listed column
                    for c in reversed(cset_no_text):
                        if c.shares_lineage(binary.right) and (
                            not only_synonyms or c.name == binary.left.name
                        ):
                            omit.add(c)
                            break

        for clause in clauses:
            if clause is not None:
                visitors.traverse(clause, {}, {"binary": visit_binary})

    return column_set.difference(omit)
972
973
def criterion_as_pairs(
    expression,
    consider_as_foreign_keys=None,
    consider_as_referenced_keys=None,
    any_operator=False,
):
    """traverse an expression and locate binary criterion pairs.

    Returns a list of 2-tuples, one per qualifying binary comparison between
    two ColumnElements, each ordered as ``(referenced, foreign)``:

    * with ``consider_as_foreign_keys``, the side found in that collection
      is placed second;
    * with ``consider_as_referenced_keys``, the side found in that
      collection is placed first;
    * with neither, only Column-to-Column comparisons where one column
      ``references()`` the other produce a pair, with the referencing
      column placed second.

    :param any_operator: when False (the default) only ``==`` comparisons
     are considered.
    :raises ArgumentError: if both ``consider_as_foreign_keys`` and
     ``consider_as_referenced_keys`` are given.
    """

    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError(
            "Can only specify one of "
            "'consider_as_foreign_keys' or "
            "'consider_as_referenced_keys'"
        )

    def col_is(a, b):
        # structural comparison via ClauseElement.compare(), not identity
        # return a is b
        return a.compare(b)

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return
        if not isinstance(binary.left, ColumnElement) or not isinstance(
            binary.right, ColumnElement
        ):
            return

        if consider_as_foreign_keys:
            # a side in the collection counts as "foreign" unless the other
            # side is also in it (in which case only a self-comparison of
            # the same column qualifies)
            if binary.left in consider_as_foreign_keys and (
                col_is(binary.right, binary.left)
                or binary.right not in consider_as_foreign_keys
            ):
                pairs.append((binary.right, binary.left))
            elif binary.right in consider_as_foreign_keys and (
                col_is(binary.left, binary.right)
                or binary.left not in consider_as_foreign_keys
            ):
                pairs.append((binary.left, binary.right))
        elif consider_as_referenced_keys:
            # mirror image of the branch above: the matching side goes first
            if binary.left in consider_as_referenced_keys and (
                col_is(binary.right, binary.left)
                or binary.right not in consider_as_referenced_keys
            ):
                pairs.append((binary.left, binary.right))
            elif binary.right in consider_as_referenced_keys and (
                col_is(binary.left, binary.right)
                or binary.left not in consider_as_referenced_keys
            ):
                pairs.append((binary.right, binary.left))
        else:
            # fall back to declared foreign keys between real Columns
            if isinstance(binary.left, Column) and isinstance(
                binary.right, Column
            ):
                if binary.left.references(binary.right):
                    pairs.append((binary.right, binary.left))
                elif binary.right.references(binary.left):
                    pairs.append((binary.left, binary.right))

    # populated by visit_binary (closure) during the traversal below
    pairs: List[Tuple[ColumnElement[Any], ColumnElement[Any]]] = []
    visitors.traverse(expression, {}, {"binary": visit_binary})
    return pairs
1035
1036
1037class ClauseAdapter(visitors.ReplacingExternalTraversal):
1038 """Clones and modifies clauses based on column correspondence.
1039
1040 E.g.::
1041
1042 table1 = Table(
1043 "sometable",
1044 metadata,
1045 Column("col1", Integer),
1046 Column("col2", Integer),
1047 )
1048 table2 = Table(
1049 "someothertable",
1050 metadata,
1051 Column("col1", Integer),
1052 Column("col2", Integer),
1053 )
1054
1055 condition = table1.c.col1 == table2.c.col1
1056
1057 make an alias of table1::
1058
1059 s = table1.alias("foo")
1060
1061 calling ``ClauseAdapter(s).traverse(condition)`` converts
1062 condition to read::
1063
1064 s.c.col1 == table2.c.col1
1065
1066 """
1067
1068 __slots__ = (
1069 "__traverse_options__",
1070 "selectable",
1071 "include_fn",
1072 "exclude_fn",
1073 "equivalents",
1074 "adapt_on_names",
1075 "adapt_from_selectables",
1076 )
1077
1078 def __init__(
1079 self,
1080 selectable: Selectable,
1081 equivalents: Optional[_EquivalentColumnMap] = None,
1082 include_fn: Optional[Callable[[ClauseElement], bool]] = None,
1083 exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
1084 adapt_on_names: bool = False,
1085 anonymize_labels: bool = False,
1086 adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
1087 ):
1088 self.__traverse_options__ = {
1089 "stop_on": [selectable],
1090 "anonymize_labels": anonymize_labels,
1091 }
1092 self.selectable = selectable
1093 self.include_fn = include_fn
1094 self.exclude_fn = exclude_fn
1095 self.equivalents = util.column_dict(equivalents or {})
1096 self.adapt_on_names = adapt_on_names
1097 self.adapt_from_selectables = adapt_from_selectables
1098
    # Typing-only overloads for traverse(): TYPE_CHECKING is False at
    # runtime, so none of these definitions exist then and the inherited
    # implementation is used.
    if TYPE_CHECKING:

        @overload
        def traverse(self, obj: Literal[None]) -> None: ...

        # note this specializes the ReplacingExternalTraversal.traverse()
        # method to state
        # that we will return the same kind of ExternalTraversal object as
        # we were given.  This is probably not 100% true, such as it's
        # possible for us to swap out Alias for Table at the top level.
        # Ideally there could be overloads specific to ColumnElement and
        # FromClause but Mypy is not accepting those as compatible with
        # the base ReplacingExternalTraversal
        @overload
        def traverse(self, obj: _ET) -> _ET: ...

        def traverse(
            self, obj: Optional[ExternallyTraversible]
        ) -> Optional[ExternallyTraversible]: ...
1118
1119 def _corresponding_column(
1120 self, col, require_embedded, _seen=util.EMPTY_SET
1121 ):
1122 newcol = self.selectable.corresponding_column(
1123 col, require_embedded=require_embedded
1124 )
1125 if newcol is None and col in self.equivalents and col not in _seen:
1126 for equiv in self.equivalents[col]:
1127 newcol = self._corresponding_column(
1128 equiv,
1129 require_embedded=require_embedded,
1130 _seen=_seen.union([col]),
1131 )
1132 if newcol is not None:
1133 return newcol
1134
1135 if (
1136 self.adapt_on_names
1137 and newcol is None
1138 and isinstance(col, NamedColumn)
1139 ):
1140 newcol = self.selectable.exported_columns.get(col.name)
1141 return newcol
1142
1143 @util.preload_module("sqlalchemy.sql.functions")
1144 def replace(
1145 self, col: _ET, _include_singleton_constants: bool = False
1146 ) -> Optional[_ET]:
1147 functions = util.preloaded.sql_functions
1148
1149 # TODO: cython candidate
1150
1151 if self.include_fn and not self.include_fn(col): # type: ignore
1152 return None
1153 elif self.exclude_fn and self.exclude_fn(col): # type: ignore
1154 return None
1155
1156 if isinstance(col, FromClause) and not isinstance(
1157 col, functions.FunctionElement
1158 ):
1159 if self.selectable.is_derived_from(col):
1160 if self.adapt_from_selectables:
1161 for adp in self.adapt_from_selectables:
1162 if adp.is_derived_from(col):
1163 break
1164 else:
1165 return None
1166 return self.selectable # type: ignore
1167 elif isinstance(col, Alias) and isinstance(
1168 col.element, TableClause
1169 ):
1170 # we are a SELECT statement and not derived from an alias of a
1171 # table (which nonetheless may be a table our SELECT derives
1172 # from), so return the alias to prevent further traversal
1173 # or
1174 # we are an alias of a table and we are not derived from an
1175 # alias of a table (which nonetheless may be the same table
1176 # as ours) so, same thing
1177 return col # type: ignore
1178 else:
1179 # other cases where we are a selectable and the element
1180 # is another join or selectable that contains a table which our
1181 # selectable derives from, that we want to process
1182 return None
1183
1184 elif not isinstance(col, ColumnElement):
1185 return None
1186 elif not _include_singleton_constants and col._is_singleton_constant:
1187 # dont swap out NULL, TRUE, FALSE for a label name
1188 # in a SQL statement that's being rewritten,
1189 # leave them as the constant. This is first noted in #6259,
1190 # however the logic to check this moved here as of #7154 so that
1191 # it is made specific to SQL rewriting and not all column
1192 # correspondence
1193
1194 return None
1195
1196 if "adapt_column" in col._annotations:
1197 col = col._annotations["adapt_column"]
1198
1199 if TYPE_CHECKING:
1200 assert isinstance(col, KeyedColumnElement)
1201
1202 if self.adapt_from_selectables and col not in self.equivalents:
1203 for adp in self.adapt_from_selectables:
1204 if adp.c.corresponding_column(col, False) is not None:
1205 break
1206 else:
1207 return None
1208
1209 if TYPE_CHECKING:
1210 assert isinstance(col, KeyedColumnElement)
1211
1212 return self._corresponding_column( # type: ignore
1213 col, require_embedded=True
1214 )
1215
1216
class _ColumnLookup(Protocol):
    """Structural type of :attr:`.ColumnAdapter.columns`.

    Only ``__getitem__`` is required.  The overloads declare that looking
    up ``None`` yields ``None`` and that any other key yields an element of
    the same kind as the key.

    NOTE(review): with ``adapt_required`` set, ColumnAdapter._locate_col
    can produce ``None`` for a non-None key, which these overloads do not
    express.
    """

    @overload
    def __getitem__(self, key: None) -> None: ...

    @overload
    def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: ...

    @overload
    def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: ...

    @overload
    def __getitem__(self, key: _ET) -> _ET: ...

    def __getitem__(self, key: Any) -> Any: ...
1231
1232
class ColumnAdapter(ClauseAdapter):
    """A :class:`.ClauseAdapter` that memoizes its results and can pass
    them on through further adapters.

    What this class adds on top of :class:`.ClauseAdapter`:

    * Adapted expressions are kept in a persistent ``.columns`` lookup, so
      adapting an expression E that once produced E1 produces that same E1
      object again.  This matters in particular for anonymized
      :class:`.Label` objects, letting the adapter present a consistent
      "adapted" view of things.

    * ``include_fn`` / ``exclude_fn`` rules are applied in front of the
      persistent lookup rather than via hash identity, because
      "annotated" items all share the hash identity of their parent.

    * "Wrapping" (see :meth:`.wrap`) lets the replacement of an expression
      proceed through a series of adapters.  Unlike the visitor's
      "chaining" feature, which stops at the first replacing function that
      returns non-None, the result is passed through every wrapped adapter
      unconditionally.

    * ``adapt_required``, used by eager loading, makes an expression that
      is not translated resolve to ``None``, so an untranslated result row
      column is not trusted.  This keeps a column from being interpreted
      as that of the child row in a self-referential scenario; see
      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency

    """

    __slots__ = (
        "columns",
        "adapt_required",
        "allow_label_resolve",
        "_wrap",
        "__weakref__",
    )

    columns: _ColumnLookup

    def __init__(
        self,
        selectable: Selectable,
        equivalents: Optional[_EquivalentColumnMap] = None,
        adapt_required: bool = False,
        include_fn: Optional[Callable[[ClauseElement], bool]] = None,
        exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
        adapt_on_names: bool = False,
        allow_label_resolve: bool = True,
        anonymize_labels: bool = False,
        adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
    ):
        super().__init__(
            selectable,
            equivalents=equivalents,
            include_fn=include_fn,
            exclude_fn=exclude_fn,
            adapt_on_names=adapt_on_names,
            anonymize_labels=anonymize_labels,
            adapt_from_selectables=adapt_from_selectables,
        )

        self.columns = self._make_column_lookup()
        self.adapt_required = adapt_required
        self.allow_label_resolve = allow_label_resolve
        self._wrap = None

    def _make_column_lookup(self) -> _ColumnLookup:
        # Build the memoizing lookup behind ``.columns``: entries are
        # populated on demand by _locate_col(), and when include/exclude
        # rules are configured they are applied in front of it.
        lookup = util.WeakPopulateDict(self._locate_col)
        if self.include_fn or self.exclude_fn:
            return self._IncludeExcludeMapping(self, lookup)  # type: ignore
        return lookup  # type: ignore

    class _IncludeExcludeMapping:
        # Lookup front-end that applies the parent adapter's include_fn /
        # exclude_fn before consulting the underlying memoized lookup.

        def __init__(self, parent, columns):
            self.parent = parent
            self.columns = columns

        def __getitem__(self, key):
            parent = self.parent
            include_fn = parent.include_fn
            exclude_fn = parent.exclude_fn

            rejected = (include_fn and not include_fn(key)) or (
                exclude_fn and exclude_fn(key)
            )
            if not rejected:
                return self.columns[key]

            # A rejected key is not adapted by this adapter; a wrapped
            # adapter, if any, still gets to translate it.
            wrapped = parent._wrap
            return wrapped.columns[key] if wrapped else key

    def wrap(self, adapter):
        """Return a copy of this adapter whose results are additionally
        passed through ``adapter``.

        The copy gets its own, empty memoized lookup.
        """
        wrapped = copy.copy(self)
        wrapped._wrap = adapter
        wrapped.columns = wrapped._make_column_lookup()
        return wrapped

    @overload
    def traverse(self, obj: Literal[None]) -> None: ...

    @overload
    def traverse(self, obj: _ET) -> _ET: ...

    def traverse(
        self, obj: Optional[ExternallyTraversible]
    ) -> Optional[ExternallyTraversible]:
        # served from the memoized lookup rather than a fresh traversal
        return self.columns[obj]

    def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
        # only other ColumnAdapters may be chained onto a ColumnAdapter
        assert isinstance(visitor, ColumnAdapter)
        return super().chain(visitor)

    if TYPE_CHECKING:

        @property
        def visitor_iterator(self) -> Iterator[ColumnAdapter]: ...

    adapt_clause = traverse
    adapt_list = ClauseAdapter.copy_and_process

    def adapt_check_present(
        self, col: ColumnElement[Any]
    ) -> Optional[ColumnElement[Any]]:
        """Adapt ``col``, returning ``None`` if it came back unchanged and
        has no corresponding column in :attr:`.selectable`."""
        adapted = self.columns[col]
        if adapted is not col:
            return adapted

        # the lookup left the column untouched; only hand it back when it
        # genuinely corresponds to something in our selectable
        if self._corresponding_column(col, True) is None:
            return None
        return adapted

    def _locate_col(
        self, col: ColumnElement[Any]
    ) -> Optional[ColumnElement[Any]]:
        # Both replace() and a full traverse() do more than this lookup
        # really needs, and a slimmer inlined version would be better.
        # The full traverse() is still needed in general because e.g. a
        # labeled scalar subquery may need its insides adapted.  For an
        # Immutable element, such as a Column object, the "clone" / "copy
        # internals" step would be a no-op, so each visitor's replace() is
        # consulted directly instead, with singleton constants (null /
        # true / false) included so that they are adapted here as well.
        if col._is_immutable:
            candidates = (
                vis.replace(col, _include_singleton_constants=True)
                for vis in self.visitor_iterator
            )
            located = next(
                (found for found in candidates if found is not None), col
            )
        else:
            located = ClauseAdapter.traverse(self, col)

        if self._wrap:
            rewrapped = self._wrap._locate_col(located)
            if rewrapped is not None:
                located = rewrapped

        if self.adapt_required and located is col:
            return None

        # allow_label_resolve is consumed by one joined eager loading case
        # to keep its own columns from being affected by .order_by().  It
        # is only applied to a newly produced element, which is assumed to
        # be local to the adapted selectable; before the ORM was fully
        # typed it was also set on the incoming object when nothing was
        # found, which should not happen.
        if located is not col:
            located._allow_label_resolve = self.allow_label_resolve

        return located
1404
1405
def _offset_or_limit_clause(
    element: _LimitOffsetType,
    name: Optional[str] = None,
    type_: Optional[_TypeEngineArgument[int]] = None,
) -> ColumnElement[int]:
    """Coerce ``element`` into an "offset or limit" SQL expression.

    Incoming integers are converted into an expression; an argument that
    is already an expression passes through.  The work is delegated to
    ``coercions.expect()`` under the ``LimitOffsetRole`` role, with
    ``name`` and ``type_`` forwarded.

    """
    role = roles.LimitOffsetRole
    return coercions.expect(role, element, type_=type_, name=name)
1420
1421
1422def _offset_or_limit_clause_asint_if_possible(
1423 clause: _LimitOffsetType,
1424) -> _LimitOffsetType:
1425 """Return the offset or limit clause as a simple integer if possible,
1426 else return the clause.
1427
1428 """
1429 if clause is None:
1430 return None
1431 if hasattr(clause, "_limit_offset_value"):
1432 value = clause._limit_offset_value
1433 return util.asint(value)
1434 else:
1435 return clause
1436
1437
1438def _make_slice(
1439 limit_clause: _LimitOffsetType,
1440 offset_clause: _LimitOffsetType,
1441 start: int,
1442 stop: int,
1443) -> Tuple[Optional[ColumnElement[int]], Optional[ColumnElement[int]]]:
1444 """Compute LIMIT/OFFSET in terms of slice start/end"""
1445
1446 # for calculated limit/offset, try to do the addition of
1447 # values to offset in Python, however if a SQL clause is present
1448 # then the addition has to be on the SQL side.
1449
1450 # TODO: typing is finding a few gaps in here, see if they can be
1451 # closed up
1452
1453 if start is not None and stop is not None:
1454 offset_clause = _offset_or_limit_clause_asint_if_possible(
1455 offset_clause
1456 )
1457 if offset_clause is None:
1458 offset_clause = 0
1459
1460 if start != 0:
1461 offset_clause = offset_clause + start # type: ignore
1462
1463 if offset_clause == 0:
1464 offset_clause = None
1465 else:
1466 assert offset_clause is not None
1467 offset_clause = _offset_or_limit_clause(offset_clause)
1468
1469 limit_clause = _offset_or_limit_clause(stop - start)
1470
1471 elif start is None and stop is not None:
1472 limit_clause = _offset_or_limit_clause(stop)
1473 elif start is not None and stop is None:
1474 offset_clause = _offset_or_limit_clause_asint_if_possible(
1475 offset_clause
1476 )
1477 if offset_clause is None:
1478 offset_clause = 0
1479
1480 if start != 0:
1481 offset_clause = offset_clause + start
1482
1483 if offset_clause == 0:
1484 offset_clause = None
1485 else:
1486 offset_clause = _offset_or_limit_clause(offset_clause)
1487
1488 return limit_clause, offset_clause