Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/sqlalchemy/sql/util.py: 23%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# sql/util.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
9"""High level utilities which build upon other modules here."""
10from __future__ import annotations
12from collections import deque
13import copy
14from itertools import chain
15import typing
16from typing import AbstractSet
17from typing import Any
18from typing import Callable
19from typing import cast
20from typing import Collection
21from typing import Dict
22from typing import Iterable
23from typing import Iterator
24from typing import List
25from typing import Literal
26from typing import Optional
27from typing import overload
28from typing import Protocol
29from typing import Sequence
30from typing import Tuple
31from typing import TYPE_CHECKING
32from typing import TypeVar
33from typing import Union
35from . import coercions
36from . import operators
37from . import roles
38from . import visitors
39from ._typing import is_text_clause
40from .annotation import _deep_annotate as _deep_annotate # noqa: F401
41from .annotation import _deep_deannotate as _deep_deannotate # noqa: F401
42from .annotation import _shallow_annotate as _shallow_annotate # noqa: F401
43from .base import _expand_cloned
44from .base import _from_objects
45from .cache_key import HasCacheKey as HasCacheKey # noqa: F401
46from .ddl import sort_tables as sort_tables # noqa: F401
47from .elements import _find_columns as _find_columns
48from .elements import _label_reference
49from .elements import _textual_label_reference
50from .elements import BindParameter
51from .elements import ClauseElement
52from .elements import ColumnClause
53from .elements import ColumnElement
54from .elements import Grouping
55from .elements import KeyedColumnElement
56from .elements import Label
57from .elements import NamedColumn
58from .elements import Null
59from .elements import UnaryExpression
60from .schema import Column
61from .selectable import Alias
62from .selectable import FromClause
63from .selectable import FromGrouping
64from .selectable import Join
65from .selectable import ScalarSelect
66from .selectable import SelectBase
67from .selectable import TableClause
68from .visitors import _ET
69from .. import exc
70from .. import util
71from ..util.typing import Unpack
73if typing.TYPE_CHECKING:
74 from ._typing import _EquivalentColumnMap
75 from ._typing import _LimitOffsetType
76 from ._typing import _TypeEngineArgument
77 from .elements import AbstractTextClause
78 from .elements import BinaryExpression
79 from .selectable import _JoinTargetElement
80 from .selectable import _SelectIterable
81 from .selectable import Selectable
82 from .visitors import _TraverseCallableType
83 from .visitors import ExternallyTraversible
84 from .visitors import ExternalTraversal
85 from ..engine.interfaces import _AnyExecuteParams
86 from ..engine.interfaces import _AnyMultiExecuteParams
87 from ..engine.interfaces import _AnySingleExecuteParams
88 from ..engine.interfaces import _CoreSingleExecuteParams
89 from ..engine.row import Row
# TypeVar bound to ColumnElement, used by helpers that accept a
# ColumnElement subclass and return the same type (e.g.
# ``adapt_criterion_to_null``).
_CE = TypeVar("_CE", bound="ColumnElement[Any]")
def join_condition(
    a: FromClause,
    b: FromClause,
    a_subset: Optional[FromClause] = None,
    consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]] = None,
) -> ColumnElement[bool]:
    """Create a join condition between two tables or selectables.

    e.g.::

        join_condition(tablea, tableb)

    would produce an expression along the lines of::

        tablea.c.id == tableb.c.tablea_id

    The condition is derived from the foreign key relationships between
    the two selectables; an error is raised when there is more than one
    way to join them, or no way at all.

    :param a_subset: An optional expression that is a sub-component
     of ``a``.  A join to just this sub-component is attempted before
     the full ``a`` construct is considered; if one is found it is used
     even when other ways to join to ``a`` exist.  This allows the
     "right side" of a join to be passed thereby providing a
     "natural join".

    :param consider_as_foreign_keys: optional set of columns; passed
     through unchanged to ``Join._join_condition``.

    """
    condition = Join._join_condition(
        a,
        b,
        consider_as_foreign_keys=consider_as_foreign_keys,
        a_subset=a_subset,
    )
    return condition
def find_join_source(
    clauses: List[FromClause], join_to: FromClause
) -> List[int]:
    """Given a list of FROM clauses and a selectable, return the indexes
    of the clauses which are derived from any of the selectable's FROM
    objects.

    Each matching index appears at most once, in ascending order.  An
    empty list is returned if no clause matches.

    e.g.::

        clause1 = table1.join(table2)
        clause2 = table4.join(table5)

        join_to = table2.join(table3)

        find_join_source([clause1, clause2], join_to) == [0]

    """
    selectables = list(_from_objects(join_to))
    idx = []
    for i, f in enumerate(clauses):
        for s in selectables:
            if f.is_derived_from(s):
                idx.append(i)
                # one match is enough; without this break the same index
                # would be appended once per matching selectable
                break
    return idx
def find_left_clause_that_matches_given(
    clauses: Sequence[FromClause], join_from: FromClause
) -> List[int]:
    """Given a list of FROM clauses and a selectable, return the indexes
    of the clauses which are derived from the selectable.

    A first, liberal pass keeps every clause that ``is_derived_from`` any
    FROM object of ``join_from``.  If that finds more than one clause, a
    second, conservative pass keeps only those sharing a surface
    selectable with one of those FROM objects.  The conservative result
    is used when it is non-empty.

    """
    selectables = list(_from_objects(join_from))

    # basic check, if f is derived from s.  This can be joins containing
    # a table, or an aliased table or select statement matching to a
    # table.  It matches a table to a selectable adapted from that table;
    # with Query, this suits a join made to an adapted entity.
    liberal_idx = [
        position
        for position, candidate in enumerate(clauses)
        if any(candidate.is_derived_from(sel) for sel in selectables)
    ]

    if len(liberal_idx) <= 1:
        return liberal_idx

    # in an extremely small set of use cases the target table is
    # represented in more than one FROM clause (embedded or similar); try
    # for a more exact match that does not look at adaption relationships
    conservative_idx = [
        position
        for position in liberal_idx
        if any(
            set(surface_selectables(clauses[position])).intersection(
                surface_selectables(sel)
            )
            for sel in selectables
        )
    ]
    return conservative_idx or liberal_idx
def find_left_clause_to_join_from(
    clauses: Sequence[FromClause],
    join_to: _JoinTargetElement,
    onclause: Optional[ColumnElement[Any]],
) -> List[int]:
    """Given a list of FROM clauses, a selectable, and an optional ON
    clause, return a list of integer indexes from the clauses list
    indicating the clauses that can be joined from.

    The presence of an "onclause" indicates that at least one clause can
    definitely be joined from.  If the list of clauses has one element and
    the onclause is given, that index is returned.  If there is more than
    one clause and the onclause is given, the clauses whose columns,
    together with the target's, cover every column in the onclause are
    selected.  If none qualify, all indexes are returned.  Clauses hidden
    by another clause's ``_hide_froms`` are dropped when more than one
    index matches.

    """
    selectables = set(_from_objects(join_to))

    # if we are given more than one target clause to join
    # from, use the onclause to provide a more specific answer.
    # otherwise, don't try to limit, after all, "ON TRUE" is a valid
    # on clause
    if len(clauses) > 1 and onclause is not None:
        resolve_ambiguity = True
        cols_in_onclause = _find_columns(onclause)
    else:
        resolve_ambiguity = False
        cols_in_onclause = None

    def _joinable(left_from, target):
        # True when ``left_from`` is an acceptable left side for ``target``
        if resolve_ambiguity:
            assert cols_in_onclause is not None
            return set(left_from.c).union(target.c).issuperset(
                cols_in_onclause
            )
        return onclause is not None or Join._can_join(left_from, target)

    idx = [
        position
        for position, candidate in enumerate(clauses)
        if any(
            _joinable(candidate, target)
            for target in selectables.difference([candidate])
        )
    ]

    if len(idx) > 1:
        # this is the same "hide froms" logic from
        # Selectable._get_display_froms
        toremove = set(
            chain.from_iterable(
                _expand_cloned(f._hide_froms) for f in clauses
            )
        )
        idx = [position for position in idx if clauses[position] not in toremove]

    # onclause was given and none of them resolved, so assume
    # all indexes can match
    if idx or onclause is None:
        return idx
    return list(range(len(clauses)))
def visit_binary_product(
    fn: Callable[
        [BinaryExpression[Any], ColumnElement[Any], ColumnElement[Any]], None
    ],
    expr: ColumnElement[Any],
) -> None:
    """Produce a traversal of the given expression, delivering
    column comparisons to the given function.

    The function is of the form::

        def my_fn(binary, left, right): ...

    For each binary expression located which has a
    comparison operator, the product of "left" and
    "right" will be delivered to that function,
    in terms of that binary.

    Hence an expression like::

        and_((a + b) == q + func.sum(e + f), j == r)

    would have the traversal:

    .. sourcecode:: text

        a <eq> q
        a <eq> e
        a <eq> f
        b <eq> q
        b <eq> e
        b <eq> f
        j <eq> r

    That is, every combination of "left" and
    "right" that doesn't further contain
    a binary comparison is passed as pairs.

    Correlated subqueries (``ScalarSelect``) are not descended into; each
    is treated as a single column element.  Nothing is returned; results
    are delivered only through ``fn``.
    """
    # comparison binaries currently being expanded; insert(0)/pop(0) keeps
    # the most recently entered comparison at index 0, and that is the
    # binary handed to ``fn``
    stack: List[BinaryExpression[Any]] = []

    def visit(element: ClauseElement) -> Iterator[ColumnElement[Any]]:
        # generator: yields the ScalarSelect / ColumnClause elements found
        # beneath ``element`` outside of any nested comparison; nested
        # comparisons are consumed here by calling ``fn``
        if isinstance(element, ScalarSelect):
            # we don't want to dig into correlated subqueries,
            # those are just column elements by themselves
            yield element
        elif element.__visit_name__ == "binary" and operators.is_comparison(
            element.operator  # type: ignore
        ):
            stack.insert(0, element)  # type: ignore
            # cartesian product of the columns on each side.  The right
            # side is re-walked once per column yielded on the left, and
            # not walked at all if the left side yields nothing.
            for l in visit(element.left):  # type: ignore
                for r in visit(element.right):  # type: ignore
                    fn(stack[0], l, r)
            stack.pop(0)
            # NOTE(review): ``visit`` is a generator function, so
            # ``visit(elem)`` only creates a generator object that is never
            # iterated; this loop has no effect.  Confirm whether it was
            # meant to be consumed (e.g. ``list(visit(elem))``) or can be
            # removed.
            for elem in element.get_children():
                visit(elem)
        else:
            if isinstance(element, ColumnClause):
                yield element
            for elem in element.get_children():
                yield from visit(elem)

    # drive the generator to completion; its yielded values are discarded
    list(visit(expr))
    visit = None  # type: ignore  # remove gc cycles
def find_tables(
    clause: ClauseElement,
    *,
    check_columns: bool = False,
    include_aliases: bool = False,
    include_joins: bool = False,
    include_selects: bool = False,
    include_crud: bool = False,
) -> List[TableClause]:
    """Locate Table objects within the given expression.

    Plain tables are always collected.  Each ``include_*`` flag also
    collects the corresponding kind of construct.  With ``include_crud``,
    the ``.table`` of INSERT/UPDATE/DELETE constructs is collected; with
    ``check_columns``, the ``.table`` of every column encountered is
    collected.  Entries are returned in the order the traversal reports
    them and are not de-duplicated.
    """
    tables: List[TableClause] = []
    collect = tables.append

    def collect_table_of(element):
        collect(element.table)

    _visitors: Dict[str, _TraverseCallableType[Any]] = {}

    def register(callback, *visit_names):
        for visit_name in visit_names:
            _visitors[visit_name] = callback

    if include_selects:
        register(collect, "select", "compound_select")
    if include_joins:
        register(collect, "join")
    if include_aliases:
        register(collect, "alias", "subquery", "tablesample", "lateral")
    if include_crud:
        register(collect_table_of, "insert", "update", "delete")
    if check_columns:
        register(collect_table_of, "column")
    register(collect, "table")

    visitors.traverse(clause, {}, _visitors)
    return tables
def unwrap_order_by(clause: Any) -> Any:
    """Break up an 'order by' expression into individual column-expressions,
    without DESC/ASC/NULLS FIRST/NULLS LAST

    The expression is walked breadth-first (a ``deque`` consumed from the
    left).  The following are unwrapped:

    * labels whose element is not a ``ScalarSelect``, together with a
      ``Grouping`` sitting directly inside such a label;
    * ``_label_reference`` wrappers.

    ``_textual_label_reference`` elements are dropped.  Unary expressions
    carrying an ordering modifier, and anything that is not a
    ``ColumnElement``, are descended into.  Every other column expression
    is returned once, in the order first reached.
    """

    cols = util.column_set()
    result = []
    stack = deque([clause])

    # examples
    # column -> ASC/DESC == column
    # column -> ASC/DESC -> label == column
    # column -> label -> ASC/DESC -> label == column
    # scalar_select -> label -> ASC/DESC == scalar_select -> label

    while stack:
        t = stack.popleft()
        # a column expression that is not itself an ordering modifier
        # (DESC/ASC/NULLS ...) is a candidate result
        if isinstance(t, ColumnElement) and (
            not isinstance(t, UnaryExpression)
            or not operators.is_ordering_modifier(t.modifier)  # type: ignore
        ):
            if isinstance(t, Label) and not isinstance(
                t.element, ScalarSelect
            ):
                t = t.element

                if isinstance(t, Grouping):
                    t = t.element

                # re-queue the unwrapped element so it is examined again
                stack.append(t)
                continue
            elif isinstance(t, _label_reference):
                t = t.element

                stack.append(t)
                continue
            if isinstance(t, (_textual_label_reference)):
                continue
            # de-duplicate while preserving first-seen order
            if t not in cols:
                cols.add(t)
                result.append(t)

        else:
            # ordering modifier or non-column element: look inside it
            for c in t.get_children():
                stack.append(c)
    return result
def unwrap_label_reference(element):
    """Return ``element`` processed by ``visitors.replacement_traverse`` so
    that every ``_label_reference`` is replaced by the element it wraps.

    Encountering a ``_textual_label_reference`` trips an assertion; such a
    reference cannot be unwrapped.
    """

    def replace(
        element: ExternallyTraversible, **kw: Any
    ) -> Optional[ExternallyTraversible]:
        if isinstance(element, _label_reference):
            return element.element
        if isinstance(element, _textual_label_reference):
            assert False, "can't unwrap a textual label reference"
        # None tells replacement_traverse to leave this element in place
        return None

    return visitors.replacement_traverse(element, {}, replace)
def expand_column_list_from_order_by(collist, order_by):
    """Given the columns clause and ORDER BY of a selectable,
    return a list of column expressions that can be added to the collist
    corresponding to the ORDER BY, without repeating those already
    in the collist.

    """
    # entries that have an ``_order_by_label_element`` are recorded by
    # their ``.element``; everything else is recorded as-is
    cols_already_present = set()
    for col in collist:
        if col._order_by_label_element is not None:
            cols_already_present.add(col.element)
        else:
            cols_already_present.add(col)

    candidates = chain.from_iterable(unwrap_order_by(o) for o in order_by)
    return [
        candidate
        for candidate in candidates
        if candidate not in cols_already_present
    ]
def clause_is_present(clause, search):
    """Given a target clause and a second to search within, return True
    if the target is plainly present in the search without any
    subqueries or aliases involved.

    Basically descends through Joins (and FromGrouping wrappers) via
    ``surface_selectables``.

    """
    # use == rather than ``is`` so that Annotated's compare is honored
    return any(clause == elem for elem in surface_selectables(search))
def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]:
    """Yield the non-join FROM elements of ``clause`` from left to right.

    A ``Join`` is expanded into its left side followed by its right side,
    and a ``FromGrouping`` is replaced by its element; any other element
    is yielded as-is.
    """
    pending: List[FromClause] = [clause]
    while pending:
        current = pending.pop()
        if isinstance(current, Join):
            # push the right side first so the left side is expanded first
            pending.append(current.right)
            pending.append(current.left)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
        else:
            yield current
def surface_expressions(clause):
    """Yield ``clause`` and, depth-first, the children of every
    ``ColumnElement`` reached.

    Elements that are not ``ColumnElement`` are yielded but not descended
    into.  Children are taken from a LIFO stack, so siblings come out in
    reverse ``get_children()`` order.
    """
    pending = [clause]
    while pending:
        current = pending.pop()
        yield current
        if not isinstance(current, ColumnElement):
            continue
        pending.extend(current.get_children())
def surface_selectables(clause):
    """Yield ``clause`` followed by everything reachable through ``Join``
    (left and right sides) and ``FromGrouping`` (wrapped element).

    Other elements are yielded but not descended into.  A LIFO stack is
    used, so a join's right side is yielded before its left side.
    """
    pending = [clause]
    while pending:
        current = pending.pop()
        yield current
        if isinstance(current, Join):
            children = (current.left, current.right)
        elif isinstance(current, FromGrouping):
            children = (current.element,)
        else:
            continue
        pending.extend(children)
def surface_selectables_only(clause: ClauseElement) -> Iterator[ClauseElement]:
    """Yield the table-like elements at the "surface" of ``clause``.

    ``Join`` and ``FromGrouping`` elements are descended into rather than
    yielded.  A ``ColumnClause`` is replaced by its ``.table`` when it has
    one, and is yielded itself otherwise.  ``TableClause`` / ``Alias``
    objects and any other non-``None`` element are yielded once each time
    they are reached.
    """
    stack = [clause]
    while stack:
        elem = stack.pop()
        # a single if/elif chain, so a table or alias is yielded exactly
        # once instead of again by the trailing ``elem is not None`` branch
        if isinstance(elem, (TableClause, Alias)):
            yield elem
        elif isinstance(elem, Join):
            stack.extend((elem.left, elem.right))
        elif isinstance(elem, FromGrouping):
            stack.append(elem.element)
        elif isinstance(elem, ColumnClause):
            if elem.table is not None:
                stack.append(elem.table)
            else:
                yield elem
        elif elem is not None:
            yield elem
def extract_first_column_annotation(column, annotation_name):
    """Return the value of the first annotation named ``annotation_name``
    found in a breadth-first walk starting at ``column``, or None.

    Children that are ``FromGrouping`` or ``SelectBase`` are not entered.
    """
    skip_types = (FromGrouping, SelectBase)

    queue = deque([column])
    while queue:
        current = queue.popleft()
        if annotation_name in current._annotations:
            return current._annotations[annotation_name]
        queue.extend(
            child
            for child in current.get_children()
            if not isinstance(child, skip_types)
        )
    return None
def selectables_overlap(left: FromClause, right: FromClause) -> bool:
    """Return True if left/right have some overlapping selectable.

    Both sides are expanded with ``surface_selectables`` (descending
    through joins and groupings) and compared as sets.
    """
    left_surface = set(surface_selectables(left))
    return not left_surface.isdisjoint(surface_selectables(right))
def bind_values(clause):
    """Return an ordered list of "bound" values in the given clause.

    Each bind parameter contributes its ``effective_value``, in the order
    the traversal visits them.

    E.g.::

        >>> expr = and_(table.c.foo == 5, table.c.foo == 7)
        >>> bind_values(expr)
        [5, 7]
    """
    values = []
    visitors.traverse(
        clause,
        {},
        {"bindparam": lambda bind: values.append(bind.effective_value)},
    )
    return values
552def _quote_ddl_expr(element):
553 if isinstance(element, str):
554 element = element.replace("'", "''")
555 return "'%s'" % element
556 else:
557 return repr(element)
560class _repr_base:
561 _LIST: int = 0
562 _TUPLE: int = 1
563 _DICT: int = 2
565 __slots__ = ("max_chars",)
567 max_chars: int
569 def trunc(self, value: Any) -> str:
570 rep = repr(value)
571 lenrep = len(rep)
572 if lenrep > self.max_chars:
573 segment_length = self.max_chars // 2
574 rep = (
575 rep[0:segment_length]
576 + (
577 " ... (%d characters truncated) ... "
578 % (lenrep - self.max_chars)
579 )
580 + rep[-segment_length:]
581 )
582 return rep
def _repr_single_value(value):
    """Return ``repr(value)`` truncated via ``_repr_base.trunc`` with a
    300-character limit."""
    truncator = _repr_base()
    truncator.max_chars = 300
    return truncator.trunc(value)
class _repr_row(_repr_base):
    """Provide a string view of a row."""

    __slots__ = ("row",)

    def __init__(
        self, row: Row[Unpack[Tuple[Any, ...]]], max_chars: int = 300
    ):
        self.row = row
        self.max_chars = max_chars

    def __repr__(self) -> str:
        # rendered like a tuple literal with each value truncated on its
        # own; a one-element row gets the trailing comma, e.g. "(1,)"
        rendered = ", ".join(map(self.trunc, self.row))
        trailing_comma = "," if len(self.row) == 1 else ""
        return "(%s%s)" % (rendered, trailing_comma)
610class _long_statement(str):
611 def __str__(self) -> str:
612 lself = len(self)
613 if lself > 500:
614 lleft = 250
615 lright = 100
616 trunc = lself - lleft - lright
617 return (
618 f"{self[0:lleft]} ... {trunc} "
619 f"characters truncated ... {self[-lright:]}"
620 )
621 else:
622 return str.__str__(self)
class _repr_params(_repr_base):
    """Provide a string view of bound parameters.

    Truncates display to a given number of 'multi' parameter sets,
    as well as long values to a given number of characters.

    :param params: the parameters to render.
    :param batches: for a multi-parameter execution, the number of
     parameter sets displayed before the rest are elided.
    :param max_params: the number of individual parameters within one
     parameter set displayed before the rest are elided; half are taken
     from the start and half from the end.
    :param max_chars: per-value truncation length used by ``trunc``.
    :param ismulti: ``True`` when ``params`` is a sequence of parameter
     sets, ``False`` for a single set, ``None`` to render ``params`` with a
     plain ``trunc``.

    """

    __slots__ = "params", "batches", "ismulti", "max_params"

    def __init__(
        self,
        params: Optional[_AnyExecuteParams],
        batches: int,
        max_params: int = 100,
        max_chars: int = 300,
        ismulti: Optional[bool] = None,
    ):
        self.params = params
        self.ismulti = ismulti
        self.batches = batches
        self.max_chars = max_chars
        self.max_params = max_params

    def __repr__(self) -> str:
        if self.ismulti is None:
            return self.trunc(self.params)

        if isinstance(self.params, list):
            typ = self._LIST

        elif isinstance(self.params, tuple):
            typ = self._TUPLE
        elif isinstance(self.params, dict):
            typ = self._DICT
        else:
            return self.trunc(self.params)

        if self.ismulti:
            multi_params = cast(
                "_AnyMultiExecuteParams",
                self.params,
            )

            if len(self.params) > self.batches:
                msg = (
                    " ... displaying %i of %i total bound parameter sets ... "
                )
                # render the leading and trailing sets separately, then
                # strip the closing / opening bracket so the two halves
                # can be joined around the message
                return " ".join(
                    (
                        self._repr_multi(
                            multi_params[: self.batches - 2],
                            typ,
                        )[0:-1],
                        msg % (self.batches, len(self.params)),
                        self._repr_multi(multi_params[-2:], typ)[1:],
                    )
                )
            else:
                return self._repr_multi(multi_params, typ)
        else:
            return self._repr_params(
                cast(
                    "_AnySingleExecuteParams",
                    self.params,
                ),
                typ,
            )

    def _repr_multi(
        self,
        multi_params: _AnyMultiExecuteParams,
        typ: int,
    ) -> str:
        """Render a sequence of parameter sets, bracketed according to
        ``typ`` (square brackets for ``_LIST``, parentheses otherwise)."""
        if multi_params:
            # the container type of every set is taken from the first one
            if isinstance(multi_params[0], list):
                elem_type = self._LIST
            elif isinstance(multi_params[0], tuple):
                elem_type = self._TUPLE
            elif isinstance(multi_params[0], dict):
                elem_type = self._DICT
            else:
                assert False, "Unknown parameter type %s" % (
                    type(multi_params[0])
                )

            elements = ", ".join(
                self._repr_params(params, elem_type) for params in multi_params
            )
        else:
            elements = ""

        if typ == self._LIST:
            return "[%s]" % elements
        else:
            return "(%s)" % elements

    def _get_batches(self, params: Iterable[Any]) -> Any:
        """Split ``params`` for display.

        Returns ``(all_items, None, None)`` when there are at most
        ``max_params`` items.  Otherwise returns
        ``(head, tail, omitted_count)``, where ``head`` and ``tail`` each
        hold ``max_params // 2`` items.
        """
        lparams = list(params)
        lenparams = len(lparams)
        if lenparams > self.max_params:
            lleft = max(self.max_params // 2, 0)
            return (
                lparams[0:lleft],
                # absolute offset: ``lparams[-0:]`` would be the whole list
                lparams[lenparams - lleft :],
                # count what is actually hidden, which differs from
                # ``lenparams - max_params`` when max_params is odd
                lenparams - 2 * lleft,
            )
        else:
            return lparams, None, None

    def _repr_params(
        self,
        params: _AnySingleExecuteParams,
        typ: int,
    ) -> str:
        """Render one parameter set according to its container tag."""
        if typ == self._DICT:
            return self._repr_param_dict(
                cast("_CoreSingleExecuteParams", params)
            )
        elif typ == self._TUPLE:
            return self._repr_param_tuple(cast("Sequence[Any]", params))
        else:
            return self._repr_param_list(params)

    def _repr_param_dict(self, params: _CoreSingleExecuteParams) -> str:
        """Render a mapping parameter set as ``{key: value, ...}``."""
        trunc = self.trunc
        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params.items())

        # trunclen is None exactly when nothing was elided; testing the
        # tail batch instead would miss an empty tail (max_params < 2)
        if trunclen is not None:
            text = "{%s" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_first_batch
                )
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s}" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_second_batch
                )
            )
        else:
            text = "{%s}" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_first_batch
                )
            )
        return text

    def _repr_param_tuple(self, params: Sequence[Any]) -> str:
        """Render a tuple parameter set, keeping the trailing comma of a
        one-element tuple."""
        trunc = self.trunc

        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params)

        if trunclen is not None:
            text = "(%s" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s)" % (
                ", ".join(trunc(value) for value in items_second_batch),
            )
        else:
            text = "(%s%s)" % (
                ", ".join(trunc(value) for value in items_first_batch),
                "," if len(items_first_batch) == 1 else "",
            )
        return text

    def _repr_param_list(self, params: _AnySingleExecuteParams) -> str:
        """Render a list parameter set as ``[value, ...]``."""
        trunc = self.trunc
        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params)

        if trunclen is not None:
            text = "[%s" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s]" % (
                ", ".join(trunc(value) for value in items_second_batch)
            )
        else:
            text = "[%s]" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
        return text
def adapt_criterion_to_null(crit: _CE, nulls: Collection[Any]) -> _CE:
    """Given criterion containing bind params, convert selected elements
    to IS NULL.

    ``crit`` is processed with ``visitors.cloned_traverse``.  In each
    binary expression, one side may be a ``BindParameter`` whose
    ``_identifying_key`` is in ``nulls``.  That side is replaced by
    ``Null()`` and the operator becomes ``operators.is_`` (negation
    ``operators.is_not``).  When the matching parameter is on the left,
    the right operand is moved to the left so the NULL ends up on the
    right.
    """

    def _is_null_param(operand):
        return (
            isinstance(operand, BindParameter)
            and operand._identifying_key in nulls
        )

    def visit_binary(binary):
        if _is_null_param(binary.left):
            # reverse order if the NULL is on the left side
            binary.left = binary.right
        elif not _is_null_param(binary.right):
            return
        binary.right = Null()
        binary.operator = operators.is_
        binary.negate = operators.is_not

    return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
def splice_joins(
    left: Optional[FromClause],
    right: Optional[FromClause],
    stop_on: Optional[FromClause] = None,
) -> Optional[FromClause]:
    """Adapt ``right`` against ``left`` along ``right``'s left-hand spine.

    If ``left`` is None, ``right`` is returned unchanged.  Otherwise
    ``right`` is processed, and then, for as long as the element is a
    ``Join`` other than ``stop_on``, its ``.left`` is processed next.

    Each such ``Join`` is cloned and its ``onclause`` is adapted by a
    ``ClauseAdapter`` built from ``left``.  The first element that is not
    such a ``Join`` is passed through ``adapter.traverse`` as a whole.
    Each clone's ``.left`` is re-pointed at the processed element beneath
    it, and the outermost processed element is returned.
    """
    if left is None:
        return right

    # (element to process, the cloned Join whose .left it should become)
    stack: List[Tuple[Optional[FromClause], Optional[Join]]] = [(right, None)]

    adapter = ClauseAdapter(left)
    ret = None
    while stack:
        # note: the ``right`` parameter name is reused as the loop variable
        (right, prevright) = stack.pop()
        if isinstance(right, Join) and right is not stop_on:
            right = right._clone()
            right.onclause = adapter.traverse(right.onclause)
            # continue down the left side of this join
            stack.append((right.left, right))
        else:
            right = adapter.traverse(right)
        if prevright is not None:
            assert right is not None
            # attach the processed child to its (already cloned) parent
            prevright.left = right
        if ret is None:
            # the first element processed is the outermost one
            ret = right

    return ret
@overload
def reduce_columns(
    columns: Iterable[ColumnElement[Any]],
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Sequence[ColumnElement[Any]]: ...


@overload
def reduce_columns(
    columns: _SelectIterable,
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Sequence[Union[ColumnElement[Any], AbstractTextClause]]: ...


def reduce_columns(
    columns: _SelectIterable,
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Collection[Union[ColumnElement[Any], AbstractTextClause]]:
    r"""given a list of columns, return a 'reduced' set based on natural
    equivalents.

    the set is reduced to the smallest list of columns which have no natural
    equivalent present in the list.  A "natural equivalent" means that two
    columns will ultimately represent the same value because they are related
    by a foreign key.

    \*clauses is an optional list of join clauses which will be traversed
    to further identify columns that are "equivalent".

    \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
    whose tables are not yet configured, or columns that aren't yet present.
    It may also specify 'only_synonyms'.  In that case a column is only
    omitted when a name comparison also matches (see the ``.name`` checks
    below).  Any other keys are ignored.

    This function is primarily used to determine the most minimal "primary
    key" from a selectable, by reducing the set of primary key columns present
    in the selectable to just those that are not repeated.

    The result is ``util.OrderedSet(columns)`` minus the omitted columns;
    text clauses in ``columns`` are never omitted.

    """
    ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
    only_synonyms = kw.pop("only_synonyms", False)

    column_set = util.OrderedSet(columns)
    # text clauses are excluded from the analysis below; since they are
    # never added to ``omit`` they remain in the returned set
    cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference(
        c for c in column_set if is_text_clause(c)  # type: ignore
    )

    omit = util.column_set()
    # pass 1: omit ``col`` when a foreign key on any column in its proxy
    # set points at a column sharing lineage with another column in the set
    for col in cset_no_text:
        for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
            for c in cset_no_text:
                if c is col:
                    continue
                # NOTE(review): ``fk.column`` is resolved once per ``c``
                # (and never if ``col`` is the only column); hoisting it
                # would change when the errors below are raised.
                try:
                    fk_col = fk.column
                except exc.NoReferencedColumnError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                except exc.NoReferencedTableError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                if fk_col.shares_lineage(c) and (
                    not only_synonyms or c.name == col.name
                ):
                    omit.add(col)
                    break

    if clauses:
        # pass 2: for each ``==`` comparison in the given clauses whose
        # both sides belong to the not-yet-omitted columns, omit the last
        # column (in reverse order) sharing lineage with the right side

        def visit_binary(binary):
            if binary.operator == operators.eq:
                cols = util.column_set(
                    chain(
                        *[c.proxy_set for c in cset_no_text.difference(omit)]
                    )
                )
                if binary.left in cols and binary.right in cols:
                    for c in reversed(cset_no_text):
                        if c.shares_lineage(binary.right) and (
                            not only_synonyms or c.name == binary.left.name
                        ):
                            omit.add(c)
                            break

        for clause in clauses:
            if clause is not None:
                visitors.traverse(clause, {}, {"binary": visit_binary})

    return column_set.difference(omit)
def criterion_as_pairs(
    expression,
    consider_as_foreign_keys=None,
    consider_as_referenced_keys=None,
    any_operator=False,
):
    """Traverse an expression and locate binary criterion pairs.

    Only ``==`` comparisons are considered unless ``any_operator`` is
    true, and both operands must be ``ColumnElement`` objects.

    * With ``consider_as_foreign_keys``, a pair is emitted when one side
      is in that collection (and the other side either ``compare()``-s
      equal to it or is not in it); the collection's column is placed
      second.
    * With ``consider_as_referenced_keys``, the same test applies but the
      collection's column is placed first.
    * Otherwise, for two ``Column`` operands, ``Column.references`` is
      consulted: ``(right, left)`` when ``left.references(right)``, else
      ``(left, right)`` when ``right.references(left)``.

    Passing both collections raises ``exc.ArgumentError``.
    """
    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError(
            "Can only specify one of "
            "'consider_as_foreign_keys' or "
            "'consider_as_referenced_keys'"
        )

    pairs: List[Tuple[ColumnElement[Any], ColumnElement[Any]]] = []

    def _keyed(keys, col, other):
        # ``col`` is one of ``keys`` and ``other`` is either the same
        # column (per ``compare``) or not one of ``keys``
        return col in keys and (other.compare(col) or other not in keys)

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return
        left, right = binary.left, binary.right
        if not (
            isinstance(left, ColumnElement) and isinstance(right, ColumnElement)
        ):
            return

        if consider_as_foreign_keys:
            if _keyed(consider_as_foreign_keys, left, right):
                pairs.append((right, left))
            elif _keyed(consider_as_foreign_keys, right, left):
                pairs.append((left, right))
        elif consider_as_referenced_keys:
            if _keyed(consider_as_referenced_keys, left, right):
                pairs.append((left, right))
            elif _keyed(consider_as_referenced_keys, right, left):
                pairs.append((right, left))
        elif isinstance(left, Column) and isinstance(right, Column):
            if left.references(right):
                pairs.append((right, left))
            elif right.references(left):
                pairs.append((left, right))

    visitors.traverse(expression, {}, {"binary": visit_binary})
    return pairs
1046class ClauseAdapter(visitors.ReplacingExternalTraversal):
1047 """Clones and modifies clauses based on column correspondence.
1049 E.g.::
1051 table1 = Table(
1052 "sometable",
1053 metadata,
1054 Column("col1", Integer),
1055 Column("col2", Integer),
1056 )
1057 table2 = Table(
1058 "someothertable",
1059 metadata,
1060 Column("col1", Integer),
1061 Column("col2", Integer),
1062 )
1064 condition = table1.c.col1 == table2.c.col1
1066 make an alias of table1::
1068 s = table1.alias("foo")
1070 calling ``ClauseAdapter(s).traverse(condition)`` converts
1071 condition to read::
1073 s.c.col1 == table2.c.col1
1075 """
1077 __slots__ = (
1078 "__traverse_options__",
1079 "selectable",
1080 "include_fn",
1081 "exclude_fn",
1082 "equivalents",
1083 "adapt_on_names",
1084 "adapt_from_selectables",
1085 )
1087 def __init__(
1088 self,
1089 selectable: Selectable,
1090 equivalents: Optional[_EquivalentColumnMap] = None,
1091 include_fn: Optional[Callable[[ClauseElement], bool]] = None,
1092 exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
1093 adapt_on_names: bool = False,
1094 anonymize_labels: bool = False,
1095 adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
1096 ):
1097 self.__traverse_options__ = {
1098 "stop_on": [selectable],
1099 "anonymize_labels": anonymize_labels,
1100 }
1101 self.selectable = selectable
1102 self.include_fn = include_fn
1103 self.exclude_fn = exclude_fn
1104 self.equivalents = util.column_dict(equivalents or {})
1105 self.adapt_on_names = adapt_on_names
1106 self.adapt_from_selectables = adapt_from_selectables
1108 if TYPE_CHECKING:
1110 @overload
1111 def traverse(self, obj: Literal[None]) -> None: ...
1113 # note this specializes the ReplacingExternalTraversal.traverse()
1114 # method to state
1115 # that we will return the same kind of ExternalTraversal object as
1116 # we were given. This is probably not 100% true, such as it's
1117 # possible for us to swap out Alias for Table at the top level.
1118 # Ideally there could be overloads specific to ColumnElement and
1119 # FromClause but Mypy is not accepting those as compatible with
1120 # the base ReplacingExternalTraversal
1121 @overload
1122 def traverse(self, obj: _ET) -> _ET: ...
1124 def traverse(
1125 self, obj: Optional[ExternallyTraversible]
1126 ) -> Optional[ExternallyTraversible]: ...
def _corresponding_column(
    self, col, require_embedded, _seen=util.EMPTY_SET
):
    """Find the column of ``self.selectable`` that corresponds to *col*.

    Strategies are tried in order, and the first non-``None`` hit wins:

    1. ``self.selectable.corresponding_column(col)`` directly.
    2. The same search, recursively, for each column listed for *col* in
       ``self.equivalents``.  ``_seen`` accumulates the columns already
       expanded on this path so that cyclic equivalences terminate.
    3. When ``adapt_on_names`` is set and *col* is a ``NamedColumn``, a
       lookup of ``col.name`` in ``self.selectable.exported_columns``.

    :param col: the column to locate.
    :param require_embedded: forwarded to ``corresponding_column()``.
    :param _seen: internal recursion guard; callers should not pass it.
    :return: the matching column, or ``None`` if no strategy finds one.
    """
    found = self.selectable.corresponding_column(
        col, require_embedded=require_embedded
    )
    if found is not None:
        return found

    # direct lookup failed; try each equivalent column, skipping any column
    # already expanded higher up this recursion path
    if col in self.equivalents and col not in _seen:
        seen_on_path = _seen.union([col])
        for equivalent in self.equivalents[col]:
            found = self._corresponding_column(
                equivalent,
                require_embedded=require_embedded,
                _seen=seen_on_path,
            )
            if found is not None:
                return found

    # final fallback: match purely by name when configured to do so
    if self.adapt_on_names and isinstance(col, NamedColumn):
        return self.selectable.exported_columns.get(col.name)

    return None
@util.preload_module("sqlalchemy.sql.functions")
def replace(
    self, col: _ET, _include_singleton_constants: bool = False
) -> Optional[_ET]:
    """Return the replacement for *col* against ``self.selectable``.

    ``None`` means this adapter supplies no replacement for *col*.  For
    FROM elements, the comments below distinguish two cases.  Returning
    ``col`` itself (a table alias) is done "to prevent further
    traversal".  Returning ``None`` is used for elements "that we want to
    process".

    The checks run in this order:

    * ``include_fn`` / ``exclude_fn``.
    * The FromClause handling.
    * The ColumnElement filters.
    * The ``adapt_from_selectables`` restriction.
    * A final ``_corresponding_column(..., require_embedded=True)`` lookup.

    :param col: the element being visited.
    :param _include_singleton_constants: when True, singleton constants
     (NULL, TRUE, FALSE per the comment below) are not excluded from
     adaptation.
    :return: the replacement element, or ``None``.
    """
    functions = util.preloaded.sql_functions

    # TODO: cython candidate

    # elements rejected by the user-supplied include/exclude predicates
    # are never replaced
    if self.include_fn and not self.include_fn(col):  # type: ignore
        return None
    elif self.exclude_fn and self.exclude_fn(col):  # type: ignore
        return None

    # FROM-level elements; SQL functions are FromClauses too but are
    # deliberately excluded here
    if isinstance(col, FromClause) and not isinstance(
        col, functions.FunctionElement
    ):
        if self.selectable.is_derived_from(col):
            # when adapt_from_selectables is given, only swap in our
            # selectable if one of those selectables also derives from col
            if self.adapt_from_selectables:
                for adp in self.adapt_from_selectables:
                    if adp.is_derived_from(col):
                        break
                else:
                    return None
            return self.selectable  # type: ignore
        elif isinstance(col, Alias) and isinstance(
            col.element, TableClause
        ):
            # we are a SELECT statement and not derived from an alias of a
            # table (which nonetheless may be a table our SELECT derives
            # from), so return the alias to prevent further traversal
            # or
            # we are an alias of a table and we are not derived from an
            # alias of a table (which nonetheless may be the same table
            # as ours) so, same thing
            return col
        else:
            # other cases where we are a selectable and the element
            # is another join or selectable that contains a table which our
            # selectable derives from, that we want to process
            return None

    elif not isinstance(col, ColumnElement):
        return None
    elif not _include_singleton_constants and col._is_singleton_constant:
        # dont swap out NULL, TRUE, FALSE for a label name
        # in a SQL statement that's being rewritten,
        # leave them as the constant. This is first noted in #6259,
        # however the logic to check this moved here as of #7154 so that
        # it is made specific to SQL rewriting and not all column
        # correspondence

        return None

    # an "adapt_column" annotation redirects the lookup to that column
    if "adapt_column" in col._annotations:
        col = col._annotations["adapt_column"]

    # typing-only narrowing; skipped at runtime
    if TYPE_CHECKING:
        assert isinstance(col, KeyedColumnElement)

    # when adapt_from_selectables is given, a column that is not in the
    # equivalents map is only adapted if one of those selectables has a
    # corresponding column for it
    if self.adapt_from_selectables and col not in self.equivalents:
        for adp in self.adapt_from_selectables:
            if adp.c.corresponding_column(col, False) is not None:
                break
        else:
            return None

    # NOTE(review): duplicates the TYPE_CHECKING assertion above; harmless,
    # since this block is skipped at runtime
    if TYPE_CHECKING:
        assert isinstance(col, KeyedColumnElement)

    return self._corresponding_column(  # type: ignore
        col, require_embedded=True
    )
class _ColumnLookup(Protocol):
    """Typing protocol for the ``ColumnAdapter.columns`` lookup object.

    Indexing with ``None`` returns ``None``.  Indexing with a column or
    other element is declared to return an object of the same kind.

    NOTE(review): ``ColumnAdapter._locate_col``, which populates this
    lookup, can return ``None`` for a non-``None`` key when
    ``adapt_required`` is set; these overloads do not express that.
    """

    @overload
    def __getitem__(self, key: None) -> None: ...

    @overload
    def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: ...

    @overload
    def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: ...

    @overload
    def __getitem__(self, key: _ET) -> _ET: ...

    # implementation signature for the overloads above
    def __getitem__(self, key: Any) -> Any: ...
class ColumnAdapter(ClauseAdapter):
    """Extends ClauseAdapter with extra utility functions.

    Key aspects of ColumnAdapter include:

    * Expressions that are adapted are stored in a persistent
      .columns collection, so that an expression E adapted into
      an expression E1 will return the same object E1 when adapted
      a second time. This is important in particular for things like
      Label objects that are anonymized, so that the ColumnAdapter can
      be used to present a consistent "adapted" view of things.

    * Exclusion of items from the persistent collection based on
      include/exclude rules, but also independent of hash identity.
      This is because "annotated" items all have the same hash identity
      as their parent.

    * "wrapping" capability is added, so that the replacement of an
      expression E can proceed through a series of adapters. This differs
      from the visitor's "chaining" feature in that the resulting object
      is passed through all replacing functions unconditionally, rather
      than stopping at the first one that returns non-None.

    * An adapt_required option, used by eager loading to indicate that
      we don't trust a result row column that is not translated.
      This is to prevent a column from being interpreted as that
      of the child row in a self-referential scenario, see
      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency

    """

    __slots__ = (
        "columns",
        "adapt_required",
        "allow_label_resolve",
        "_wrap",
        "__weakref__",
    )

    # lookup from an incoming expression to its adapted form; built in
    # __init__ and rebuilt for each copy made by wrap()
    columns: _ColumnLookup

    def __init__(
        self,
        selectable: Selectable,
        equivalents: Optional[_EquivalentColumnMap] = None,
        adapt_required: bool = False,
        include_fn: Optional[Callable[[ClauseElement], bool]] = None,
        exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
        adapt_on_names: bool = False,
        allow_label_resolve: bool = True,
        anonymize_labels: bool = False,
        adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
    ):
        """Construct a ColumnAdapter targeting *selectable*.

        Every argument except ``adapt_required`` and ``allow_label_resolve``
        is passed through to ``ClauseAdapter.__init__``.

        :param adapt_required: when True, a lookup whose result is the
         incoming object itself (i.e. nothing was adapted) yields ``None``.
        :param allow_label_resolve: value assigned to
         ``_allow_label_resolve`` on each newly adapted expression.
        """
        super().__init__(
            selectable,
            equivalents,
            include_fn=include_fn,
            exclude_fn=exclude_fn,
            adapt_on_names=adapt_on_names,
            anonymize_labels=anonymize_labels,
            adapt_from_selectables=adapt_from_selectables,
        )

        # entries are computed on demand by _locate_col; this is the
        # "persistent .columns collection" described in the class docstring
        self.columns = util.WeakPopulateDict(self._locate_col)  # type: ignore
        # with include/exclude rules, front the collection with a mapping
        # that routes rejected keys around it
        if self.include_fn or self.exclude_fn:
            self.columns = self._IncludeExcludeMapping(self, self.columns)
        self.adapt_required = adapt_required
        self.allow_label_resolve = allow_label_resolve
        # no wrapped adapter until wrap() is called on a copy
        self._wrap = None

    class _IncludeExcludeMapping:
        """Front a column lookup so include/exclude rules bypass its cache.

        Keys rejected by the parent's ``include_fn`` / ``exclude_fn`` never
        reach the wrapped ``columns`` mapping.  They are looked up in the
        parent's wrapped adapter if there is one, else returned unchanged.
        """

        def __init__(self, parent, columns):
            # parent: the owning ColumnAdapter; columns: the lookup to front
            self.parent = parent
            self.columns = columns

        def __getitem__(self, key):
            """Return the adapted form of *key*, honoring include/exclude."""
            if (
                self.parent.include_fn and not self.parent.include_fn(key)
            ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)):
                if self.parent._wrap:
                    return self.parent._wrap.columns[key]
                else:
                    return key
            return self.columns[key]

    def wrap(self, adapter):
        """Return a copy of this adapter whose results also pass through
        *adapter*.

        The copy's ``_locate_col`` hands each of its own results to
        ``adapter._locate_col``.  The copy gets a fresh ``columns`` lookup,
        so results already stored for ``self`` are not shared with it.
        """
        ac = copy.copy(self)
        ac._wrap = adapter
        # the shallow copy still references self.columns, whose populate
        # function is bound to self; rebuild it against the copy
        ac.columns = util.WeakPopulateDict(ac._locate_col)  # type: ignore
        if ac.include_fn or ac.exclude_fn:
            ac.columns = self._IncludeExcludeMapping(ac, ac.columns)

        return ac

    @overload
    def traverse(self, obj: Literal[None]) -> None: ...

    @overload
    def traverse(self, obj: _ET) -> _ET: ...

    def traverse(
        self, obj: Optional[ExternallyTraversible]
    ) -> Optional[ExternallyTraversible]:
        """Return the adapted form of *obj* from the ``columns`` lookup."""
        return self.columns[obj]

    def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
        """Chain *visitor*, which must itself be a ColumnAdapter."""
        assert isinstance(visitor, ColumnAdapter)

        return super().chain(visitor)

    # typing-only: narrows visitor_iterator to yield ColumnAdapter objects
    if TYPE_CHECKING:

        @property
        def visitor_iterator(self) -> Iterator[ColumnAdapter]: ...

    # public aliases: adapt_clause is this class's traverse(); adapt_list
    # is the inherited ClauseAdapter.copy_and_process
    adapt_clause = traverse
    adapt_list = ClauseAdapter.copy_and_process

    def adapt_check_present(self, col: _ET) -> Optional[_ET]:
        """Adapt *col*, or return ``None`` if it has no counterpart.

        This works like ``self.columns[col]``, but returns ``None`` when the
        lookup left *col* unchanged and ``self.selectable`` has no
        corresponding column for it.
        """
        newcol = self.columns[col]

        if newcol is col and self._corresponding_column(col, True) is None:
            return None

        return newcol

    def _locate_col(
        self, col: ColumnElement[Any]
    ) -> Optional[ColumnElement[Any]]:
        """Compute the adapted form of *col*; populates ``self.columns``.

        Returns ``None`` when ``adapt_required`` is set and no adaptation
        happened.
        """
        # both replace and traverse() are overly complicated for what
        # we are doing here and we would do better to have an inlined
        # version that doesn't build up as much overhead. the issue is that
        # sometimes the lookup does in fact have to adapt the insides of
        # say a labeled scalar subquery. However, if the object is an
        # Immutable, i.e. Column objects, we can skip the "clone" /
        # "copy internals" part since those will be no-ops in any case.
        # additionally we want to catch singleton objects null/true/false
        # and make sure they are adapted as well here.

        if col._is_immutable:
            # first visitor offering a replacement wins; otherwise keep col
            for vis in self.visitor_iterator:
                c = vis.replace(col, _include_singleton_constants=True)
                if c is not None:
                    break
            else:
                c = col
        else:
            c = ClauseAdapter.traverse(self, col)

        # a wrapped adapter gets a chance to further adapt our result
        if self._wrap:
            c2 = self._wrap._locate_col(c)
            if c2 is not None:
                c = c2

        if self.adapt_required and c is col:
            return None

        # allow_label_resolve is consumed by one case for joined eager loading
        # as part of its logic to prevent its own columns from being affected
        # by .order_by(). Before full typing was applied to the ORM, this
        # logic would set this attribute on the incoming object (which is
        # typically a column, but we have a test for it being a non-column
        # object) if no column was found. While this seemed to
        # have no negative effects, this adjustment should only occur on the
        # new column which is assumed to be local to an adapted selectable.
        if c is not col:
            c._allow_label_resolve = self.allow_label_resolve

        return c
def _offset_or_limit_clause(
    element: _LimitOffsetType,
    name: Optional[str] = None,
    type_: Optional[_TypeEngineArgument[int]] = None,
) -> ColumnElement[int]:
    """Coerce *element* into an expression usable as LIMIT or OFFSET.

    Incoming integers are converted to an expression.  A value that is
    already an expression is passed through.  Both cases are handled by
    the ``LimitOffsetRole`` coercion.

    :param element: an integer or SQL expression.
    :param name: optional name, forwarded to the coercion.
    :param type_: optional type, forwarded to the coercion.
    :return: the coerced column expression.
    """
    coerced = coercions.expect(
        roles.LimitOffsetRole,
        element,
        type_=type_,
        name=name,
    )
    return coerced
def _offset_or_limit_clause_asint_if_possible(
    clause: _LimitOffsetType,
) -> _LimitOffsetType:
    """Fold a LIMIT/OFFSET clause to a plain integer where possible.

    *clause* is returned as given in two cases: when it is ``None``, and
    when it lacks a ``_limit_offset_value`` attribute.  Otherwise the value
    of that attribute is returned, passed through ``util.asint``.
    """
    if clause is None or not hasattr(clause, "_limit_offset_value"):
        return clause
    return util.asint(clause._limit_offset_value)
def _offset_after_slice_start(
    offset_clause: _LimitOffsetType, start: int
) -> Optional[ColumnElement[int]]:
    """Return the OFFSET that results from advancing *offset_clause* by
    *start*; ``None`` when the resulting offset is zero."""
    # for calculated limit/offset, try to do the addition of
    # values to offset in Python, however if a SQL clause is present
    # then the addition has to be on the SQL side.
    offset = _offset_or_limit_clause_asint_if_possible(offset_clause)
    if offset is None:
        offset = 0

    if start != 0:
        offset = offset + start  # type: ignore

    # an offset of zero is emitted as no OFFSET at all.  The comparison is
    # kept as ``== 0`` (rather than an isinstance/identity check) so it
    # behaves exactly as before when ``offset`` is a SQL expression.
    if offset == 0:
        return None
    return _offset_or_limit_clause(offset)


def _make_slice(
    limit_clause: _LimitOffsetType,
    offset_clause: _LimitOffsetType,
    start: Optional[int],
    stop: Optional[int],
) -> Tuple[Optional[ColumnElement[int]], Optional[ColumnElement[int]]]:
    """Compute LIMIT/OFFSET in terms of slice start/end.

    :param limit_clause: the current LIMIT.  It is replaced whenever *stop*
     is given, and returned unchanged otherwise.
    :param offset_clause: the current OFFSET.  When *start* is given, the
     new OFFSET is this value plus *start*.  That sum is computed in Python
     when possible; a zero result becomes ``None``.
    :param start: slice start index, or ``None``.
    :param stop: slice stop index, or ``None``.  The new LIMIT is
     ``stop - start`` when *start* is given, otherwise ``stop``.
    :return: a ``(limit_clause, offset_clause)`` tuple.  When both *start*
     and *stop* are ``None``, the incoming clauses are returned as given.
    """
    # TODO: typing is finding a few gaps in here, see if they can be
    # closed up

    if start is not None:
        # offset is computed before limit, matching the historical order
        offset_clause = _offset_after_slice_start(offset_clause, start)
        if stop is not None:
            limit_clause = _offset_or_limit_clause(stop - start)
    elif stop is not None:
        limit_clause = _offset_or_limit_clause(stop)

    return limit_clause, offset_clause