Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/sqlalchemy/sql/util.py: 23%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# sql/util.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
9"""High level utilities which build upon other modules here."""
11from __future__ import annotations
13from collections import deque
14import copy
15from itertools import chain
16import typing
17from typing import AbstractSet
18from typing import Any
19from typing import Callable
20from typing import cast
21from typing import Collection
22from typing import Dict
23from typing import Iterable
24from typing import Iterator
25from typing import List
26from typing import Literal
27from typing import Optional
28from typing import overload
29from typing import Protocol
30from typing import Sequence
31from typing import Tuple
32from typing import TYPE_CHECKING
33from typing import TypeVar
34from typing import Union
36from . import coercions
37from . import operators
38from . import roles
39from . import visitors
40from ._typing import is_text_clause
41from .annotation import _deep_annotate as _deep_annotate # noqa: F401
42from .annotation import _deep_deannotate as _deep_deannotate # noqa: F401
43from .annotation import _shallow_annotate as _shallow_annotate # noqa: F401
44from .base import _expand_cloned
45from .base import _from_objects
46from .cache_key import HasCacheKey as HasCacheKey # noqa: F401
47from .ddl import sort_tables as sort_tables # noqa: F401
48from .elements import _find_columns as _find_columns
49from .elements import _label_reference
50from .elements import _textual_label_reference
51from .elements import BindParameter
52from .elements import ClauseElement
53from .elements import ColumnClause
54from .elements import ColumnElement
55from .elements import Grouping
56from .elements import KeyedColumnElement
57from .elements import Label
58from .elements import NamedColumn
59from .elements import Null
60from .elements import UnaryExpression
61from .schema import Column
62from .selectable import Alias
63from .selectable import FromClause
64from .selectable import FromGrouping
65from .selectable import Join
66from .selectable import ScalarSelect
67from .selectable import SelectBase
68from .selectable import TableClause
69from .visitors import _ET
70from .. import exc
71from .. import util
72from ..util.typing import Unpack
74if typing.TYPE_CHECKING:
75 from ._typing import _EquivalentColumnMap
76 from ._typing import _LimitOffsetType
77 from ._typing import _TypeEngineArgument
78 from .elements import AbstractTextClause
79 from .elements import BinaryExpression
80 from .selectable import _JoinTargetElement
81 from .selectable import _SelectIterable
82 from .selectable import Selectable
83 from .visitors import _TraverseCallableType
84 from .visitors import ExternallyTraversible
85 from .visitors import ExternalTraversal
86 from ..engine.interfaces import _AnyExecuteParams
87 from ..engine.interfaces import _AnyMultiExecuteParams
88 from ..engine.interfaces import _AnySingleExecuteParams
89 from ..engine.interfaces import _CoreSingleExecuteParams
90 from ..engine.row import Row
92_CE = TypeVar("_CE", bound="ColumnElement[Any]")
def join_condition(
    a: FromClause,
    b: FromClause,
    a_subset: Optional[FromClause] = None,
    consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]] = None,
) -> ColumnElement[bool]:
    """Build the ON criterion joining selectable ``a`` to selectable ``b``.

    e.g.::

        join_condition(tablea, tableb)

    would produce an expression along the lines of::

        tablea.c.id == tableb.c.tablea_id

    The criterion is derived from the foreign key relationships between
    the two selectables.  If there are multiple ways to join, or no way
    to join, an error is raised.

    This is a thin wrapper; all work is delegated to
    ``Join._join_condition()``.

    :param a: the left side selectable.

    :param b: the right side selectable.

    :param a_subset: An optional expression that is a sub-component
        of ``a``.  An attempt will be made to join to just this
        sub-component first before looking at the full ``a`` construct,
        and if found will be successful even if there are other ways to
        join to ``a``.  This allows the "right side" of a join to be
        passed thereby providing a "natural join".

    :param consider_as_foreign_keys: optional set of columns, passed
        through unchanged to ``Join._join_condition()``.

    :return: the boolean column expression produced by
        ``Join._join_condition()``.

    """
    criterion = Join._join_condition(
        a,
        b,
        a_subset=a_subset,
        consider_as_foreign_keys=consider_as_foreign_keys,
    )
    return criterion
def find_join_source(
    clauses: List[FromClause], join_to: FromClause
) -> List[int]:
    """Given a list of FROM clauses and a selectable, return the indexes
    of the clauses which can be joined against the selectable.

    A clause qualifies when it is derived from at least one of the FROM
    objects that make up ``join_to``.  Each qualifying index is reported
    once, in the order of ``clauses``; an empty list is returned when
    nothing matches.

    e.g.::

        clause1 = table1.join(table2)
        clause2 = table4.join(table5)

        join_to = table2.join(table3)

        # clause1 contains table2, clause2 contains neither table2
        # nor table3
        find_join_source([clause1, clause2], join_to) == [0]

    :param clauses: candidate FROM clauses to inspect.

    :param join_to: the selectable that is the target of the join.

    :return: list of integer indexes into ``clauses``.

    """

    selectables = list(_from_objects(join_to))

    # any() stops at the first selectable that matches, so a clause that
    # is derived from several of the target's FROM objects contributes
    # its index only once (same approach as
    # find_left_clause_that_matches_given()).
    return [
        i
        for i, f in enumerate(clauses)
        if any(f.is_derived_from(s) for s in selectables)
    ]
def find_left_clause_that_matches_given(
    clauses: Sequence[FromClause], join_from: FromClause
) -> List[int]:
    """Return the indexes of the FROM clauses that correspond to
    ``join_from``.

    Two passes are made.  The first, "liberal" pass accepts any clause
    that is derived from one of the FROM objects of ``join_from``.  If
    that pass finds more than one clause, a second, "conservative" pass
    keeps only those whose surface selectables directly overlap with the
    target; when the conservative pass finds nothing, the liberal result
    is returned.

    :param clauses: candidate FROM clauses.

    :param join_from: the selectable being joined from.

    :return: list of integer indexes into ``clauses``.

    """

    targets = list(_from_objects(join_from))

    # basic check, if the clause is derived from the target.
    # this can be joins containing a table, or an aliased table
    # or select statement matching to a table.  This check
    # will match a table to a selectable that is adapted from
    # that table.  With Query, this suits the case where a join
    # is being made to an adapted entity
    derived = [
        pos
        for pos, candidate in enumerate(clauses)
        if any(candidate.is_derived_from(target) for target in targets)
    ]

    if len(derived) <= 1:
        return derived

    # in an extremely small set of use cases, a join is being made where
    # there are multiple FROM clauses where our target table is represented
    # in more than one, such as embedded or similar.  in this case, do
    # another pass where we try to get a more exact match where we aren't
    # looking at adaption relationships.
    exact = [
        pos
        for pos in derived
        if any(
            set(surface_selectables(clauses[pos])).intersection(
                surface_selectables(target)
            )
            for target in targets
        )
    ]

    return exact or derived
def find_left_clause_to_join_from(
    clauses: Sequence[FromClause],
    join_to: _JoinTargetElement,
    onclause: Optional[ColumnElement[Any]],
) -> List[int]:
    """Given a list of FROM clauses, a selectable, and optional ON clause,
    return a list of integer indexes from the clauses list indicating
    the clauses that can be joined from.

    The presence of an "onclause" indicates that at least one clause can
    definitely be joined from; if the list of clauses is of length one
    and the onclause is given, returns that index.  If the list of clauses
    is more than length one, and the onclause is given, attempts to locate
    which clauses contain the same columns.

    :param clauses: candidate left-side FROM clauses.

    :param join_to: the target of the join.

    :param onclause: optional ON criterion.

    :return: list of integer indexes into ``clauses``.

    """
    targets = set(_from_objects(join_to))

    # if we are given more than one target clause to join
    # from, use the onclause to provide a more specific answer.
    # otherwise, don't try to limit, after all, "ON TRUE" is a valid
    # on clause
    resolve_ambiguity = len(clauses) > 1 and onclause is not None
    onclause_cols = _find_columns(onclause) if resolve_ambiguity else None

    matched = []
    for pos, left in enumerate(clauses):
        for target in targets.difference([left]):
            if resolve_ambiguity:
                assert onclause_cols is not None
                # the pair qualifies only if, together, they supply
                # every column named in the ON clause
                hit = (
                    set(left.c).union(target.c).issuperset(onclause_cols)
                )
            else:
                hit = onclause is not None or Join._can_join(left, target)

            if hit:
                matched.append(pos)
                break

    if len(matched) > 1:
        # this is the same "hide froms" logic from
        # Selectable._get_display_froms
        hidden = set(
            chain.from_iterable(
                _expand_cloned(f._hide_froms) for f in clauses
            )
        )
        matched = [pos for pos in matched if clauses[pos] not in hidden]

    # onclause was given and none of them resolved, so assume
    # all indexes can match
    if onclause is not None and not matched:
        return list(range(len(clauses)))
    return matched
def visit_binary_product(
    fn: Callable[
        [BinaryExpression[Any], ColumnElement[Any], ColumnElement[Any]], None
    ],
    expr: ColumnElement[Any],
) -> None:
    """Traverse the given expression, delivering column comparisons to
    the given function.

    The function is of the form::

        def my_fn(binary, left, right): ...

    For each binary expression located which has a comparison operator,
    the product of "left" and "right" will be delivered to that function,
    in terms of that binary.

    Hence an expression like::

        and_((a + b) == q + func.sum(e + f), j == r)

    would have the traversal:

    .. sourcecode:: text

        a <eq> q
        a <eq> e
        a <eq> f
        b <eq> q
        b <eq> e
        b <eq> f
        j <eq> r

    That is, every combination of "left" and "right" that doesn't further
    contain a binary comparison is passed as pairs.

    :param fn: callable receiving ``(binary, left, right)``.

    :param expr: the expression to traverse.

    """
    # comparison expressions currently being expanded; the innermost one
    # is always the last entry
    pending: List[BinaryExpression[Any]] = []

    def walk(node: ClauseElement) -> Iterator[ColumnElement[Any]]:
        if isinstance(node, ScalarSelect):
            # we don't want to dig into correlated subqueries,
            # those are just column elements by themselves
            yield node
            return

        if node.__visit_name__ == "binary" and operators.is_comparison(
            node.operator  # type: ignore
        ):
            pending.append(node)  # type: ignore
            for lhs in walk(node.left):  # type: ignore
                for rhs in walk(node.right):  # type: ignore
                    fn(pending[-1], lhs, rhs)
            pending.pop()
            # NOTE(review): walk() is a generator function, so each call
            # below only creates a generator object that is never
            # iterated; this loop does not visit anything.  It is kept
            # unchanged because actually iterating here would re-deliver
            # nested comparisons to ``fn``.
            for child in node.get_children():
                walk(child)
            return

        if isinstance(node, ColumnClause):
            yield node
        for child in node.get_children():
            yield from walk(child)

    # exhaust the generator so that every comparison is delivered
    list(walk(expr))
    walk = None  # type: ignore  # remove gc cycles
def find_tables(
    clause: ClauseElement,
    *,
    check_columns: bool = False,
    include_aliases: bool = False,
    include_joins: bool = False,
    include_selects: bool = False,
    include_crud: bool = False,
) -> List[TableClause]:
    """Locate Table objects within the given expression.

    Elements whose visit name is ``"table"`` are always collected; the
    keyword flags add further kinds of elements to the result.

    :param clause: the expression to traverse.

    :param check_columns: also collect the ``.table`` of each element
        visited as ``"column"``.

    :param include_aliases: also collect ``"alias"``, ``"subquery"``,
        ``"tablesample"`` and ``"lateral"`` elements.

    :param include_joins: also collect ``"join"`` elements.

    :param include_selects: also collect ``"select"`` and
        ``"compound_select"`` elements.

    :param include_crud: also collect the ``.table`` of ``"insert"``,
        ``"update"`` and ``"delete"`` elements.

    :return: list of collected objects in traversal order; duplicates
        are not removed.

    """

    found: List[TableClause] = []
    collect = found.append

    handlers: Dict[str, _TraverseCallableType[Any]] = {"table": collect}

    if include_selects:
        for visit_name in ("select", "compound_select"):
            handlers[visit_name] = collect

    if include_joins:
        handlers["join"] = collect

    if include_aliases:
        for visit_name in ("alias", "subquery", "tablesample", "lateral"):
            handlers[visit_name] = collect

    if include_crud:

        def collect_dml_table(stmt):
            found.append(stmt.table)

        for visit_name in ("insert", "update", "delete"):
            handlers[visit_name] = collect_dml_table

    if check_columns:

        def collect_column_table(column):
            found.append(column.table)

        handlers["column"] = collect_column_table

    visitors.traverse(clause, {}, handlers)
    return found
def unwrap_order_by(clause: Any) -> Any:
    """Break up an 'order by' expression into individual column-expressions,
    without DESC/ASC/NULLS FIRST/NULLS LAST.

    :param clause: the ORDER BY expression to unwrap.

    :return: list of the unwrapped expressions, each appearing once, in
        the order they were first encountered.

    """

    seen = util.column_set()
    unwrapped = []
    queue = deque([clause])

    # examples
    # column -> ASC/DESC == column
    # column -> ASC/DESC -> label == column
    # column -> label -> ASC/DESC -> label == column
    # scalar_select -> label -> ASC/DESC == scalar_select -> label

    while queue:
        item = queue.popleft()

        # anything that is not a column expression, as well as an
        # ASC/DESC/NULLS modifier, is not reported itself; look at what
        # it contains instead
        if not isinstance(item, ColumnElement) or (
            isinstance(item, UnaryExpression)
            and operators.is_ordering_modifier(item.modifier)  # type: ignore
        ):
            queue.extend(item.get_children())
            continue

        if isinstance(item, Label) and not isinstance(
            item.element, ScalarSelect
        ):
            # strip the label (and one level of grouping beneath it) and
            # re-examine what was inside
            inner = item.element
            if isinstance(inner, Grouping):
                inner = inner.element
            queue.append(inner)
        elif isinstance(item, _label_reference):
            queue.append(item.element)
        elif (
            not isinstance(item, _textual_label_reference)
            and item not in seen
        ):
            # textual label references are dropped; everything else is
            # reported the first time it is seen
            seen.add(item)
            unwrapped.append(item)

    return unwrapped
def unwrap_label_reference(element):
    """Return ``element`` with each ``_label_reference`` replaced by the
    expression it wraps.

    The replacement is carried out by ``visitors.replacement_traverse()``.
    Encountering a ``_textual_label_reference`` trips an assertion, as
    there is no wrapped expression to substitute for it.

    """

    def _unwrap(
        element: ExternallyTraversible, **kw: Any
    ) -> Optional[ExternallyTraversible]:
        if isinstance(element, _label_reference):
            return element.element
        assert not isinstance(
            element, _textual_label_reference
        ), "can't unwrap a textual label reference"
        # None means "leave this element as it is"
        return None

    return visitors.replacement_traverse(element, {}, _unwrap)
def expand_column_list_from_order_by(collist, order_by):
    """Given the columns clause and ORDER BY of a selectable,
    return a list of column expressions that can be added to the collist
    corresponding to the ORDER BY, without repeating those already
    in the collist.

    :param collist: the existing columns clause.

    :param order_by: iterable of ORDER BY expressions; each is passed
        through ``unwrap_order_by()``.

    :return: list of unwrapped ORDER BY expressions not already present
        in ``collist``.

    """
    present = set()
    for col in collist:
        # a column that carries an "order by label element" is compared
        # by the element it wraps rather than by itself
        if col._order_by_label_element is not None:
            present.add(col.element)
        else:
            present.add(col)

    return [
        col
        for ordering in order_by
        for col in unwrap_order_by(ordering)
        if col not in present
    ]
def clause_is_present(clause, search):
    """Given a target clause and a second to search within, return True
    if the target is plainly present in the search without any
    subqueries or aliases involved.

    Basically descends through Joins.

    :param clause: the clause to look for.

    :param search: the clause to look inside of; it is expanded with
        ``surface_selectables()``.

    :return: True if some surface selectable of ``search`` compares
        equal to ``clause``, else False.

    """

    # use == here so that Annotated's compare
    return any(clause == elem for elem in surface_selectables(search))
def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]:
    """Yield the leaf FROM elements of ``clause``, leftmost first.

    ``Join`` objects are expanded into their left and then right sides,
    ``FromGrouping`` objects are replaced by the element they wrap, and
    anything else is yielded as is.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        if isinstance(current, Join):
            # push right first so that the left side is handled first
            pending.append(current.right)
            pending.append(current.left)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
        else:
            yield current
def surface_expressions(clause):
    """Yield ``clause`` and, for every ``ColumnElement`` yielded, the
    elements returned by its ``get_children()``, recursively.

    Elements that are not ``ColumnElement`` objects are yielded but not
    descended into.  Traversal uses a last-in, first-out stack.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        yield current
        if not isinstance(current, ColumnElement):
            continue
        pending.extend(current.get_children())
def surface_selectables(clause):
    """Yield ``clause`` along with every selectable reachable from it by
    descending through ``Join`` and ``FromGrouping`` objects only.

    The ``Join`` and ``FromGrouping`` objects themselves are yielded as
    well.  Traversal uses a last-in, first-out stack, so the right side
    of a join is yielded before its left side.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        yield current
        if isinstance(current, Join):
            pending += [current.left, current.right]
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
def surface_selectables_only(clause: ClauseElement) -> Iterator[ClauseElement]:
    """Yield the selectables found at the surface of ``clause``.

    ``Join`` and ``FromGrouping`` objects are descended into and are not
    yielded themselves.  A ``ColumnClause`` is replaced by its ``.table``
    when it has one, otherwise the column itself is yielded.  Any other
    element that is not ``None`` is yielded as is.

    """
    pending = [clause]
    while pending:
        current = pending.pop()

        if isinstance(current, (TableClause, Alias)):
            yield current

        # NOTE(review): this is a separate if/elif chain, so a
        # TableClause / Alias that was just yielded above also reaches
        # the final branch and is yielded a second time (unless it is
        # also matched by one of the earlier isinstance checks).
        # Behavior is preserved here; confirm whether callers rely on it.
        if isinstance(current, Join):
            pending += [current.left, current.right]
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
        elif isinstance(current, ColumnClause):
            owner = current.table
            if owner is None:
                yield current
            else:
                pending.append(owner)
        elif current is not None:
            yield current
def extract_first_column_annotation(column, annotation_name):
    """Return the first value found for ``annotation_name`` in the
    ``_annotations`` of ``column`` or of the elements beneath it.

    The search is breadth-first over ``get_children()``; children that
    are ``FromGrouping`` or ``SelectBase`` objects are not searched.

    :param column: the element at which to start.

    :param annotation_name: the annotation key to look for.

    :return: the annotation value, or None if no searched element
        carries the key.

    """
    skipped = (FromGrouping, SelectBase)

    queue = deque([column])
    while queue:
        node = queue.popleft()
        annotations = node._annotations
        if annotation_name in annotations:
            return annotations[annotation_name]
        queue.extend(
            child
            for child in node.get_children()
            if not isinstance(child, skipped)
        )
    return None
def selectables_overlap(left: FromClause, right: FromClause) -> bool:
    """Return True if left/right have some overlapping selectable.

    Both sides are expanded with ``surface_selectables()``; the result
    is True when at least one selectable appears in both expansions.

    """

    left_surface = set(surface_selectables(left))
    return not left_surface.isdisjoint(surface_selectables(right))
def bind_values(clause):
    """Return an ordered list of "bound" values in the given clause.

    E.g.::

        >>> expr = and_(table.c.foo == 5, table.c.foo == 7)
        >>> bind_values(expr)
        [5, 7]

    Each value is the ``effective_value`` of an element visited as
    ``"bindparam"``, in traversal order.

    """

    values = []

    def collect_value(bind):
        values.append(bind.effective_value)

    visitors.traverse(clause, {}, {"bindparam": collect_value})
    return values
553def _quote_ddl_expr(element):
554 if isinstance(element, str):
555 element = element.replace("'", "''")
556 return "'%s'" % element
557 else:
558 return repr(element)
561class _repr_base:
562 _LIST: int = 0
563 _TUPLE: int = 1
564 _DICT: int = 2
566 __slots__ = ("max_chars",)
568 max_chars: int
570 def trunc(self, value: Any) -> str:
571 rep = repr(value)
572 lenrep = len(rep)
573 if lenrep > self.max_chars:
574 segment_length = self.max_chars // 2
575 rep = (
576 rep[0:segment_length]
577 + (
578 " ... (%d characters truncated) ... "
579 % (lenrep - self.max_chars)
580 )
581 + rep[-segment_length:]
582 )
583 return rep
def _repr_single_value(value):
    """Return the repr of ``value``, shortened by ``_repr_base.trunc()``
    using a limit of 300 characters.

    """
    formatter = _repr_base()
    formatter.max_chars = 300
    shortened = formatter.trunc(value)
    return shortened
class _repr_row(_repr_base):
    """Provide a string view of a row."""

    __slots__ = ("row",)

    def __init__(
        self, row: Row[Unpack[Tuple[Any, ...]]], max_chars: int = 300
    ):
        """Store the row to display and the per-value length limit."""
        self.max_chars = max_chars
        self.row = row

    def __repr__(self) -> str:
        """Render the row like a tuple, shortening each value with
        ``trunc()``; a one-element row gets a trailing comma.

        """
        rendered = ", ".join([self.trunc(value) for value in self.row])
        trailing = "," if len(self.row) == 1 else ""
        return f"({rendered}{trailing})"
611class _long_statement(str):
612 def __str__(self) -> str:
613 lself = len(self)
614 if lself > 500:
615 lleft = 250
616 lright = 100
617 trunc = lself - lleft - lright
618 return (
619 f"{self[0:lleft]} ... {trunc} "
620 f"characters truncated ... {self[-lright:]}"
621 )
622 else:
623 return str.__str__(self)
class _repr_params(_repr_base):
    """Provide a string view of bound parameters.

    Truncates display to a given number of 'multi' parameter sets,
    as well as long values to a given number of characters.

    """

    __slots__ = "params", "batches", "ismulti", "max_params"

    def __init__(
        self,
        params: Optional[_AnyExecuteParams],
        batches: int,
        max_params: int = 100,
        max_chars: int = 300,
        ismulti: Optional[bool] = None,
    ):
        """Store the parameters and the display limits.

        :param params: the parameters to display.

        :param batches: number of parameter sets above which a "multi"
            parameter list is abbreviated.

        :param max_params: number of values above which a single
            parameter set is abbreviated.

        :param max_chars: length above which the repr of a single value
            is shortened.

        :param ismulti: True if ``params`` is a sequence of parameter
            sets, False if it is a single set; None displays the plain
            shortened repr of ``params``.

        """
        self.params = params
        self.ismulti = ismulti
        self.batches = batches
        self.max_chars = max_chars
        self.max_params = max_params

    def __repr__(self) -> str:
        """Return the abbreviated string form of the parameters."""
        if self.ismulti is None:
            return self.trunc(self.params)

        if isinstance(self.params, list):
            typ = self._LIST
        elif isinstance(self.params, tuple):
            typ = self._TUPLE
        elif isinstance(self.params, dict):
            typ = self._DICT
        else:
            # not a container we know how to abbreviate
            return self.trunc(self.params)

        if not self.ismulti:
            return self._repr_params(
                cast(
                    "_AnySingleExecuteParams",
                    self.params,
                ),
                typ,
            )

        multi_params = cast(
            "_AnyMultiExecuteParams",
            self.params,
        )

        if len(self.params) <= self.batches:
            return self._repr_multi(multi_params, typ)

        msg = " ... displaying %i of %i total bound parameter sets ... "
        # the last two sets are always shown; the rest of the budget goes
        # to the leading sets.  Clamp at zero, since a negative stop
        # index would slice relative to the end of the sequence.
        head_count = max(self.batches - 2, 0)
        return " ".join(
            (
                # drop the closing bracket of the leading part...
                self._repr_multi(multi_params[:head_count], typ)[0:-1],
                msg % (self.batches, len(self.params)),
                # ...and the opening bracket of the trailing part
                self._repr_multi(multi_params[-2:], typ)[1:],
            )
        )

    def _repr_multi(
        self,
        multi_params: _AnyMultiExecuteParams,
        typ: int,
    ) -> str:
        """Render a sequence of parameter sets.

        The kind of the first set decides how every set is rendered;
        ``typ`` decides whether the whole is wrapped in ``[]`` or ``()``.

        """
        if multi_params:
            if isinstance(multi_params[0], list):
                elem_type = self._LIST
            elif isinstance(multi_params[0], tuple):
                elem_type = self._TUPLE
            elif isinstance(multi_params[0], dict):
                elem_type = self._DICT
            else:
                assert False, "Unknown parameter type %s" % (
                    type(multi_params[0])
                )

            elements = ", ".join(
                self._repr_params(params, elem_type) for params in multi_params
            )
        else:
            elements = ""

        if typ == self._LIST:
            return "[%s]" % elements
        else:
            return "(%s)" % elements

    def _get_batches(self, params: Iterable[Any]) -> Any:
        """Split ``params`` into the leading and trailing items to show.

        :return: a 3-tuple ``(first, second, omitted)``.  When the number
            of items does not exceed ``max_params``, ``first`` holds all
            items and ``second`` and ``omitted`` are None.  Otherwise
            ``first`` and ``second`` each hold ``max_params // 2`` items
            and ``omitted`` is the number of items in neither.

        """
        lparams = list(params)
        lenparams = len(lparams)
        if lenparams > self.max_params:
            lleft = max(self.max_params, 0) // 2
            return (
                lparams[:lleft],
                # explicit start index: ``lparams[-0:]`` would be the
                # whole list when lleft is zero
                lparams[lenparams - lleft :],
                # number actually left out; differs from
                # ``lenparams - max_params`` when max_params is odd
                lenparams - 2 * lleft,
            )
        else:
            return lparams, None, None

    def _repr_params(
        self,
        params: _AnySingleExecuteParams,
        typ: int,
    ) -> str:
        """Render a single parameter set according to its kind."""
        if typ == self._DICT:
            return self._repr_param_dict(
                cast("_CoreSingleExecuteParams", params)
            )
        elif typ == self._TUPLE:
            return self._repr_param_tuple(cast("Sequence[Any]", params))
        else:
            return self._repr_param_list(params)

    def _repr_param_dict(self, params: _CoreSingleExecuteParams) -> str:
        """Render a dictionary parameter set as ``{key: value, ...}``."""
        trunc = self.trunc
        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params.items())

        # test trunclen rather than the second batch, which can be an
        # empty list while items were nevertheless left out
        if trunclen is not None:
            text = "{%s" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_first_batch
                )
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s}" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_second_batch
                )
            )
        else:
            text = "{%s}" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_first_batch
                )
            )
        return text

    def _repr_param_tuple(self, params: Sequence[Any]) -> str:
        """Render a tuple parameter set as ``(value, ...)``; a complete
        one-element tuple gets a trailing comma.

        """
        trunc = self.trunc

        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params)

        if trunclen is not None:
            text = "(%s" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s)" % (
                ", ".join(trunc(value) for value in items_second_batch),
            )
        else:
            text = "(%s%s)" % (
                ", ".join(trunc(value) for value in items_first_batch),
                "," if len(items_first_batch) == 1 else "",
            )
        return text

    def _repr_param_list(self, params: _AnySingleExecuteParams) -> str:
        """Render a list parameter set as ``[value, ...]``."""
        trunc = self.trunc
        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params)

        if trunclen is not None:
            text = "[%s" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s]" % (
                ", ".join(trunc(value) for value in items_second_batch)
            )
        else:
            text = "[%s]" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
        return text
def adapt_criterion_to_null(crit: _CE, nulls: Collection[Any]) -> _CE:
    """Given criterion containing bind params, convert selected elements
    to IS NULL.

    :param crit: the criterion; it is processed with
        ``visitors.cloned_traverse()``.

    :param nulls: collection of keys; a bound parameter is converted when
        its ``_identifying_key`` is in this collection.

    :return: the result of the cloned traversal.

    """

    def is_null_bind(expr):
        return (
            isinstance(expr, BindParameter)
            and expr._identifying_key in nulls
        )

    def visit_binary(binary):
        if is_null_bind(binary.left):
            # reverse order if the NULL is on the left side
            binary.left = binary.right
        elif not is_null_bind(binary.right):
            return

        binary.right = Null()
        binary.operator = operators.is_
        binary.negate = operators.is_not

    return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
def splice_joins(
    left: Optional[FromClause],
    right: Optional[FromClause],
    stop_on: Optional[FromClause] = None,
) -> Optional[FromClause]:
    """Return ``right`` adapted with a ``ClauseAdapter`` built on ``left``.

    The chain of ``Join`` objects along the ``.left`` side of ``right``
    is cloned, each clone has its ON clause adapted, and the element at
    the bottom of the chain is adapted as well.  Descent stops at
    ``stop_on`` when given, or at the first element that is not a
    ``Join``.

    :param left: selectable to adapt towards; if None, ``right`` is
        returned unchanged.

    :param right: the structure to be adapted.

    :param stop_on: optional ``Join`` at which to stop descending.

    :return: the adapted copy of ``right``.

    """
    if left is None:
        return right

    # each entry pairs an element with the cloned join, if any, whose
    # .left it is to become
    pending: List[Tuple[Optional[FromClause], Optional[Join]]] = [
        (right, None)
    ]

    adapter = ClauseAdapter(left)
    spliced = None
    while pending:
        current, parent = pending.pop()

        if isinstance(current, Join) and current is not stop_on:
            current = current._clone()
            current.onclause = adapter.traverse(current.onclause)
            pending.append((current.left, current))
        else:
            current = adapter.traverse(current)

        if parent is not None:
            assert current is not None
            parent.left = current

        # the first element processed is the top of the structure
        if spliced is None:
            spliced = current

    return spliced
@overload
def reduce_columns(
    columns: Iterable[ColumnElement[Any]],
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Sequence[ColumnElement[Any]]: ...


@overload
def reduce_columns(
    columns: _SelectIterable,
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Sequence[Union[ColumnElement[Any], AbstractTextClause]]: ...


def reduce_columns(
    columns: _SelectIterable,
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Collection[Union[ColumnElement[Any], AbstractTextClause]]:
    r"""Given a list of columns, return a 'reduced' set based on natural
    equivalents.

    The set is reduced to the smallest list of columns which have no natural
    equivalent present in the list.  A "natural equivalent" means that two
    columns will ultimately represent the same value because they are related
    by a foreign key.

    \*clauses is an optional list of join clauses which will be traversed
    to further identify columns that are "equivalent".

    \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
    whose tables are not yet configured, or columns that aren't yet present.
    It may also specify 'only_synonyms', in which case two columns are
    treated as equivalent only if their names also match.

    This function is primarily used to determine the most minimal "primary
    key" from a selectable, by reducing the set of primary key columns present
    in the selectable to just those that are not repeated.

    """
    ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
    only_synonyms = kw.pop("only_synonyms", False)

    column_set = util.OrderedSet(columns)
    # text clauses take no part in the reduction
    cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference(
        c for c in column_set if is_text_clause(c)  # type: ignore
    )

    omit = util.column_set()

    # pass 1: omit a column when one of its foreign keys refers to
    # another column in the set
    for col in cset_no_text:
        for fk in chain.from_iterable(
            proxied.foreign_keys for proxied in col.proxy_set
        ):
            for other in cset_no_text:
                if other is col:
                    continue
                try:
                    fk_col = fk.column
                except (
                    exc.NoReferencedColumnError,
                    exc.NoReferencedTableError,
                ):
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    raise
                if fk_col.shares_lineage(other) and (
                    not only_synonyms or other.name == col.name
                ):
                    omit.add(col)
                    break

    # pass 2: omit a column that the given clauses equate with another
    # column that is still in the set
    def visit_binary(binary):
        if binary.operator != operators.eq:
            return

        # recomputed on every visit, as ``omit`` grows while traversing
        cols = util.column_set(
            chain.from_iterable(
                c.proxy_set for c in cset_no_text.difference(omit)
            )
        )
        if binary.left not in cols or binary.right not in cols:
            return

        for c in reversed(cset_no_text):
            if c.shares_lineage(binary.right) and (
                not only_synonyms or c.name == binary.left.name
            ):
                omit.add(c)
                break

    for clause in clauses:
        if clause is not None:
            visitors.traverse(clause, {}, {"binary": visit_binary})

    return column_set.difference(omit)
def criterion_as_pairs(
    expression,
    consider_as_foreign_keys=None,
    consider_as_referenced_keys=None,
    any_operator=False,
):
    """Traverse an expression and locate binary criterion pairs.

    :param expression: the expression to traverse.

    :param consider_as_foreign_keys: optional collection of columns; a
        comparison with one side in this collection produces the pair
        ``(other side, that side)``.

    :param consider_as_referenced_keys: optional collection of columns; a
        comparison with one side in this collection produces the pair
        ``(that side, other side)``.

    :param any_operator: if False, only comparisons whose operator is
        ``operators.eq`` are considered.

    :return: list of 2-tuples of column expressions.  When neither
        collection is given, pairs are produced only for comparisons
        between two ``Column`` objects where one ``references()`` the
        other, as ``(referenced column, referencing column)``.

    :raises ArgumentError: if both ``consider_as_foreign_keys`` and
        ``consider_as_referenced_keys`` are given.

    """

    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError(
            "Can only specify one of "
            "'consider_as_foreign_keys' or "
            "'consider_as_referenced_keys'"
        )

    pairs: List[Tuple[ColumnElement[Any], ColumnElement[Any]]] = []

    def split_by_membership(left, right, keys):
        # return (member, other), where "member" is the side found in
        # ``keys``; the other side must either compare() as equivalent to
        # the member or be absent from ``keys``.  None if neither side
        # qualifies.
        if left in keys and (right.compare(left) or right not in keys):
            return left, right
        if right in keys and (left.compare(right) or left not in keys):
            return right, left
        return None

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return

        left, right = binary.left, binary.right
        if not (
            isinstance(left, ColumnElement)
            and isinstance(right, ColumnElement)
        ):
            return

        if consider_as_foreign_keys:
            found = split_by_membership(
                left, right, consider_as_foreign_keys
            )
            if found is not None:
                member, other = found
                pairs.append((other, member))
        elif consider_as_referenced_keys:
            found = split_by_membership(
                left, right, consider_as_referenced_keys
            )
            if found is not None:
                pairs.append(found)
        elif isinstance(left, Column) and isinstance(right, Column):
            if left.references(right):
                pairs.append((right, left))
            elif right.references(left):
                pairs.append((left, right))

    visitors.traverse(expression, {}, {"binary": visit_binary})
    return pairs
1047class ClauseAdapter(visitors.ReplacingExternalTraversal):
1048 """Clones and modifies clauses based on column correspondence.
1050 E.g.::
1052 table1 = Table(
1053 "sometable",
1054 metadata,
1055 Column("col1", Integer),
1056 Column("col2", Integer),
1057 )
1058 table2 = Table(
1059 "someothertable",
1060 metadata,
1061 Column("col1", Integer),
1062 Column("col2", Integer),
1063 )
1065 condition = table1.c.col1 == table2.c.col1
1067 make an alias of table1::
1069 s = table1.alias("foo")
1071 calling ``ClauseAdapter(s).traverse(condition)`` converts
1072 condition to read::
1074 s.c.col1 == table2.c.col1
1076 """
1078 __slots__ = (
1079 "__traverse_options__",
1080 "selectable",
1081 "include_fn",
1082 "exclude_fn",
1083 "equivalents",
1084 "adapt_on_names",
1085 "adapt_from_selectables",
1086 )
1088 def __init__(
1089 self,
1090 selectable: Selectable,
1091 equivalents: Optional[_EquivalentColumnMap] = None,
1092 include_fn: Optional[Callable[[ClauseElement], bool]] = None,
1093 exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
1094 adapt_on_names: bool = False,
1095 anonymize_labels: bool = False,
1096 adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
1097 ):
1098 self.__traverse_options__ = {
1099 "stop_on": [selectable],
1100 "anonymize_labels": anonymize_labels,
1101 }
1102 self.selectable = selectable
1103 self.include_fn = include_fn
1104 self.exclude_fn = exclude_fn
1105 self.equivalents = util.column_dict(equivalents or {})
1106 self.adapt_on_names = adapt_on_names
1107 self.adapt_from_selectables = adapt_from_selectables
1109 if TYPE_CHECKING:
1111 @overload
1112 def traverse(self, obj: Literal[None]) -> None: ...
1114 # note this specializes the ReplacingExternalTraversal.traverse()
1115 # method to state
1116 # that we will return the same kind of ExternalTraversal object as
1117 # we were given. This is probably not 100% true, such as it's
1118 # possible for us to swap out Alias for Table at the top level.
1119 # Ideally there could be overloads specific to ColumnElement and
1120 # FromClause but Mypy is not accepting those as compatible with
1121 # the base ReplacingExternalTraversal
1122 @overload
1123 def traverse(self, obj: _ET) -> _ET: ...
1125 def traverse(
1126 self, obj: Optional[ExternallyTraversible]
1127 ) -> Optional[ExternallyTraversible]: ...
def _corresponding_column(
    self, col, require_embedded, _seen=util.EMPTY_SET
):
    """Locate the column of ``self.selectable`` that corresponds to *col*.

    The lookup proceeds in three stages, stopping at the first hit:

    1. ask ``self.selectable.corresponding_column()`` directly;
    2. if *col* has entries in ``self.equivalents``, retry recursively
       with each equivalent column (``_seen`` carries the columns already
       expanded so that an equivalence cycle is not followed twice);
    3. if ``self.adapt_on_names`` is set and *col* is a ``NamedColumn``,
       look the column up by name in ``self.selectable.exported_columns``.

    Returns the matching column, or ``None`` if no stage produced one.
    """
    match = self.selectable.corresponding_column(
        col, require_embedded=require_embedded
    )
    if match is not None:
        return match

    if col in self.equivalents and col not in _seen:
        # the same "visited" set is handed to every equivalent column
        visited = _seen.union([col])
        for equivalent in self.equivalents[col]:
            match = self._corresponding_column(
                equivalent,
                require_embedded=require_embedded,
                _seen=visited,
            )
            if match is not None:
                return match

    if self.adapt_on_names and isinstance(col, NamedColumn):
        # last resort: match purely on the column name
        return self.selectable.exported_columns.get(col.name)

    return None
@util.preload_module("sqlalchemy.sql.functions")
def replace(
    self, col: _ET, _include_singleton_constants: bool = False
) -> Optional[_ET]:
    """Return the replacement for *col*, or ``None`` for no replacement.

    :param col: the element being considered for adaptation.
    :param _include_singleton_constants: when False (the default),
     elements flagged ``_is_singleton_constant`` are never replaced.

    :return: ``self.selectable`` when *col* is a (non-function)
     ``FromClause`` that ``self.selectable`` derives from; *col* itself
     when it is an ``Alias`` of a ``TableClause`` that does not match;
     the result of ``_corresponding_column()`` for column elements;
     otherwise ``None``.
    """
    functions = util.preloaded.sql_functions

    # TODO: cython candidate

    # the include / exclude callables veto adaptation of this element
    if self.include_fn and not self.include_fn(col):  # type: ignore
        return None
    elif self.exclude_fn and self.exclude_fn(col):  # type: ignore
        return None

    if isinstance(col, FromClause) and not isinstance(
        col, functions.FunctionElement
    ):
        if self.selectable.is_derived_from(col):
            # when adapt_from_selectables is given, at least one of them
            # must also derive from col; the for/else returns None when
            # the loop finishes without finding one
            if self.adapt_from_selectables:
                for adp in self.adapt_from_selectables:
                    if adp.is_derived_from(col):
                        break
                else:
                    return None
            return self.selectable  # type: ignore
        elif isinstance(col, Alias) and isinstance(
            col.element, TableClause
        ):
            # we are a SELECT statement and not derived from an alias of a
            # table (which nonetheless may be a table our SELECT derives
            # from), so return the alias to prevent further traversal
            # or
            # we are an alias of a table and we are not derived from an
            # alias of a table (which nonetheless may be the same table
            # as ours) so, same thing
            return col
        else:
            # other cases where we are a selectable and the element
            # is another join or selectable that contains a table which our
            # selectable derives from, that we want to process
            return None

    elif not isinstance(col, ColumnElement):
        return None
    elif not _include_singleton_constants and col._is_singleton_constant:
        # dont swap out NULL, TRUE, FALSE for a label name
        # in a SQL statement that's being rewritten,
        # leave them as the constant.  This is first noted in #6259,
        # however the logic to check this moved here as of #7154 so that
        # it is made specific to SQL rewriting and not all column
        # correspondence

        return None

    # an "adapt_column" annotation names the column that should be
    # adapted in place of the element that carries it
    if "adapt_column" in col._annotations:
        col = col._annotations["adapt_column"]

    if TYPE_CHECKING:
        assert isinstance(col, KeyedColumnElement)

    # when restricted to specific source selectables, the column must
    # correspond to a column of at least one of them; columns present in
    # self.equivalents skip this restriction
    if self.adapt_from_selectables and col not in self.equivalents:
        for adp in self.adapt_from_selectables:
            if adp.c.corresponding_column(col, False) is not None:
                break
        else:
            return None

    if TYPE_CHECKING:
        assert isinstance(col, KeyedColumnElement)

    return self._corresponding_column(  # type: ignore
        col, require_embedded=True
    )
class _ColumnLookup(Protocol):
    """Structural type for an object indexable with ``[]`` by an element.

    Used as the declared type of ``ColumnAdapter.columns``.  The overloads
    state that the lookup hands back the same kind of object it was given:
    ``None`` for ``None``, a ``ColumnClause`` for a ``ColumnClause``, a
    ``ColumnElement`` for a ``ColumnElement``, and the same ``_ET`` type
    for any other traversible element.
    """

    @overload
    def __getitem__(self, key: None) -> None: ...

    @overload
    def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: ...

    @overload
    def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: ...

    @overload
    def __getitem__(self, key: _ET) -> _ET: ...

    # catch-all signature that the overloads above specialize
    def __getitem__(self, key: Any) -> Any: ...
class ColumnAdapter(ClauseAdapter):
    """A ClauseAdapter with additional utility behaviors.

    The main features added on top of ClauseAdapter are:

    * Adapted expressions are memoized in a persistent ``.columns``
      collection, so that adapting an expression E into E1 returns the very
      same E1 object when E is adapted again.  This matters in particular
      for things like anonymized Label objects, allowing the ColumnAdapter
      to present a consistent "adapted" view of things.

    * Items can be kept out of the persistent collection based on the
      include/exclude rules, independently of hash identity.  This is
      needed because "annotated" items all share the hash identity of their
      parent.

    * A "wrapping" capability, where the replacement of an expression E
      proceeds through a series of adapters.  Unlike the visitor's
      "chaining" feature, the resulting object is passed through all
      replacing functions unconditionally, rather than stopping at the
      first one that returns non-None.

    * An ``adapt_required`` option, used by eager loading to indicate that
      a result row column that is not translated is not to be trusted.
      This prevents a column from being interpreted as that of the child
      row in a self-referential scenario, see
      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency

    """

    __slots__ = (
        "columns",
        "adapt_required",
        "allow_label_resolve",
        "_wrap",
        "__weakref__",
    )

    columns: _ColumnLookup

    def __init__(
        self,
        selectable: Selectable,
        equivalents: Optional[_EquivalentColumnMap] = None,
        adapt_required: bool = False,
        include_fn: Optional[Callable[[ClauseElement], bool]] = None,
        exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
        adapt_on_names: bool = False,
        allow_label_resolve: bool = True,
        anonymize_labels: bool = False,
        adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
    ):
        super().__init__(
            selectable,
            equivalents,
            include_fn=include_fn,
            exclude_fn=exclude_fn,
            adapt_on_names=adapt_on_names,
            anonymize_labels=anonymize_labels,
            adapt_from_selectables=adapt_from_selectables,
        )

        self._wrap = None
        self.adapt_required = adapt_required
        self.allow_label_resolve = allow_label_resolve

        # memoizing lookup, populated on demand via _locate_col(); it is
        # fronted by the include/exclude filter only when a rule is present
        self.columns = util.WeakPopulateDict(self._locate_col)  # type: ignore
        if self.include_fn or self.exclude_fn:
            self.columns = self._IncludeExcludeMapping(self, self.columns)

    class _IncludeExcludeMapping:
        """Front a columns lookup with the adapter's include/exclude rules.

        Keys rejected by the rules are not looked up in the wrapped
        collection; they are routed to the parent's wrapped adapter if
        there is one, and otherwise returned as is.
        """

        def __init__(self, parent, columns):
            self.parent = parent
            self.columns = columns

        def __getitem__(self, key):
            owner = self.parent
            include_fn = owner.include_fn
            exclude_fn = owner.exclude_fn

            accepted = not (include_fn and not include_fn(key)) and not (
                exclude_fn and exclude_fn(key)
            )
            if accepted:
                return self.columns[key]

            wrapped = owner._wrap
            if wrapped:
                return wrapped.columns[key]
            return key

    def wrap(self, adapter):
        """Return a copy of this adapter that wraps *adapter*.

        The copy receives its own fresh ``.columns`` lookup (with the
        include/exclude filter applied when rules are present); this
        adapter itself is not modified.
        """
        wrapping = copy.copy(self)
        wrapping._wrap = adapter
        wrapping.columns = util.WeakPopulateDict(wrapping._locate_col)  # type: ignore
        if wrapping.include_fn or wrapping.exclude_fn:
            wrapping.columns = self._IncludeExcludeMapping(
                wrapping, wrapping.columns
            )

        return wrapping

    @overload
    def traverse(self, obj: Literal[None]) -> None: ...

    @overload
    def traverse(self, obj: _ET) -> _ET: ...

    def traverse(
        self, obj: Optional[ExternallyTraversible]
    ) -> Optional[ExternallyTraversible]:
        """Return the adapted form of *obj* via the ``.columns`` lookup."""
        return self.columns[obj]

    def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
        """Chain *visitor*, which must itself be a ColumnAdapter."""
        assert isinstance(visitor, ColumnAdapter)

        return super().chain(visitor)

    if TYPE_CHECKING:

        @property
        def visitor_iterator(self) -> Iterator[ColumnAdapter]: ...

    adapt_clause = traverse
    adapt_list = ClauseAdapter.copy_and_process

    def adapt_check_present(self, col: _ET) -> Optional[_ET]:
        """Adapt *col*, returning None if it is not present in the target.

        When the lookup hands back *col* itself and the selectable has no
        corresponding column for it, ``None`` is returned; otherwise the
        looked-up object is returned.
        """
        adapted = self.columns[col]

        if adapted is not col:
            return adapted
        if self._corresponding_column(col, True) is None:
            return None
        return adapted

    def _locate_col(
        self, col: ColumnElement[Any]
    ) -> Optional[ColumnElement[Any]]:
        """Produce the adapted column for *col*.

        This is the creator callable behind the ``.columns`` lookup.
        """
        # both replace and traverse() are overly complicated for what
        # we are doing here and we would do better to have an inlined
        # version that doesn't build up as much overhead.  the issue is that
        # sometimes the lookup does in fact have to adapt the insides of
        # say a labeled scalar subquery.   However, if the object is an
        # Immutable, i.e. Column objects, we can skip the "clone" /
        # "copy internals" part since those will be no-ops in any case.
        # additionally we want to catch singleton objects null/true/false
        # and make sure they are adapted as well here.

        if col._is_immutable:
            # first visitor offering a replacement wins; with no offer
            # the column stands for itself
            adapted = col
            for visitor in self.visitor_iterator:
                candidate = visitor.replace(
                    col, _include_singleton_constants=True
                )
                if candidate is not None:
                    adapted = candidate
                    break
        else:
            adapted = ClauseAdapter.traverse(self, col)

        wrapped = self._wrap
        if wrapped:
            rewrapped = wrapped._locate_col(adapted)
            if rewrapped is not None:
                adapted = rewrapped

        if adapted is col:
            # nothing was translated
            return None if self.adapt_required else adapted

        # allow_label_resolve is consumed by one case for joined eager loading
        # as part of its logic to prevent its own columns from being affected
        # by .order_by().  Before full typing were applied to the ORM, this
        # logic would set this attribute on the incoming object (which is
        # typically a column, but we have a test for it being a non-column
        # object) if no column were found.  While this seemed to
        # have no negative effects, this adjustment should only occur on the
        # new column which is assumed to be local to an adapted selectable.
        adapted._allow_label_resolve = self.allow_label_resolve

        return adapted
def _offset_or_limit_clause(
    element: _LimitOffsetType,
    name: Optional[str] = None,
    type_: Optional[_TypeEngineArgument[int]] = None,
) -> ColumnElement[int]:
    """Coerce *element* into an "offset or limit" clause.

    An incoming integer is converted to an expression; a value that is
    already an expression is passed through.  The coercion itself is
    delegated to ``coercions.expect()`` using ``roles.LimitOffsetRole``,
    with *name* and *type_* forwarded unchanged.

    """
    role = roles.LimitOffsetRole
    return coercions.expect(role, element, name=name, type_=type_)
1430def _offset_or_limit_clause_asint_if_possible(
1431 clause: _LimitOffsetType,
1432) -> _LimitOffsetType:
1433 """Return the offset or limit clause as a simple integer if possible,
1434 else return the clause.
1436 """
1437 if clause is None:
1438 return None
1439 if hasattr(clause, "_limit_offset_value"):
1440 value = clause._limit_offset_value
1441 return util.asint(value)
1442 else:
1443 return clause
1446def _make_slice(
1447 limit_clause: _LimitOffsetType,
1448 offset_clause: _LimitOffsetType,
1449 start: int,
1450 stop: int,
1451) -> Tuple[Optional[ColumnElement[int]], Optional[ColumnElement[int]]]:
1452 """Compute LIMIT/OFFSET in terms of slice start/end"""
1454 # for calculated limit/offset, try to do the addition of
1455 # values to offset in Python, however if a SQL clause is present
1456 # then the addition has to be on the SQL side.
1458 # TODO: typing is finding a few gaps in here, see if they can be
1459 # closed up
1461 if start is not None and stop is not None:
1462 offset_clause = _offset_or_limit_clause_asint_if_possible(
1463 offset_clause
1464 )
1465 if offset_clause is None:
1466 offset_clause = 0
1468 if start != 0:
1469 offset_clause = offset_clause + start # type: ignore
1471 if offset_clause == 0:
1472 offset_clause = None
1473 else:
1474 assert offset_clause is not None
1475 offset_clause = _offset_or_limit_clause(offset_clause)
1477 limit_clause = _offset_or_limit_clause(stop - start)
1479 elif start is None and stop is not None:
1480 limit_clause = _offset_or_limit_clause(stop)
1481 elif start is not None and stop is None:
1482 offset_clause = _offset_or_limit_clause_asint_if_possible(
1483 offset_clause
1484 )
1485 if offset_clause is None:
1486 offset_clause = 0
1488 if start != 0:
1489 offset_clause = offset_clause + start
1491 if offset_clause == 0:
1492 offset_clause = None
1493 else:
1494 offset_clause = _offset_or_limit_clause(offset_clause)
1496 return limit_clause, offset_clause