Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.11/site-packages/sqlalchemy/sql/util.py: 26%
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
Shortcuts on this page
r m x toggle line displays
j k next/prev highlighted chunk
0 (zero) top of page
1 (one) first highlighted chunk
1# sql/util.py
2# Copyright (C) 2005-2026 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
9"""High level utilities which build upon other modules here."""
11from __future__ import annotations
13from collections import deque
14import copy
15from itertools import chain
16import typing
17from typing import AbstractSet
18from typing import Any
19from typing import Callable
20from typing import cast
21from typing import Collection
22from typing import Dict
23from typing import Iterable
24from typing import Iterator
25from typing import List
26from typing import Optional
27from typing import overload
28from typing import Sequence
29from typing import Tuple
30from typing import TYPE_CHECKING
31from typing import TypeVar
32from typing import Union
34from . import coercions
35from . import operators
36from . import roles
37from . import visitors
38from ._typing import is_text_clause
39from .annotation import _deep_annotate as _deep_annotate # noqa: F401
40from .annotation import _deep_deannotate as _deep_deannotate # noqa: F401
41from .annotation import _shallow_annotate as _shallow_annotate # noqa: F401
42from .base import _expand_cloned
43from .base import _from_objects
44from .cache_key import HasCacheKey as HasCacheKey # noqa: F401
45from .ddl import sort_tables as sort_tables # noqa: F401
46from .elements import _find_columns as _find_columns
47from .elements import _label_reference
48from .elements import _textual_label_reference
49from .elements import BindParameter
50from .elements import ClauseElement
51from .elements import ColumnClause
52from .elements import ColumnElement
53from .elements import Grouping
54from .elements import KeyedColumnElement
55from .elements import Label
56from .elements import NamedColumn
57from .elements import Null
58from .elements import UnaryExpression
59from .schema import Column
60from .selectable import Alias
61from .selectable import FromClause
62from .selectable import FromGrouping
63from .selectable import Join
64from .selectable import ScalarSelect
65from .selectable import SelectBase
66from .selectable import TableClause
67from .visitors import _ET
68from .. import exc
69from .. import util
70from ..util.typing import Literal
71from ..util.typing import Protocol
73if typing.TYPE_CHECKING:
74 from ._typing import _EquivalentColumnMap
75 from ._typing import _LimitOffsetType
76 from ._typing import _TypeEngineArgument
77 from .elements import BinaryExpression
78 from .elements import TextClause
79 from .selectable import _JoinTargetElement
80 from .selectable import _SelectIterable
81 from .selectable import Selectable
82 from .visitors import _TraverseCallableType
83 from .visitors import ExternallyTraversible
84 from .visitors import ExternalTraversal
85 from ..engine.interfaces import _AnyExecuteParams
86 from ..engine.interfaces import _AnyMultiExecuteParams
87 from ..engine.interfaces import _AnySingleExecuteParams
88 from ..engine.interfaces import _CoreSingleExecuteParams
89 from ..engine.row import Row
91_CE = TypeVar("_CE", bound="ColumnElement[Any]")
def join_condition(
    a: FromClause,
    b: FromClause,
    a_subset: Optional[FromClause] = None,
    consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]] = None,
) -> ColumnElement[bool]:
    """Build the ON criteria that join selectable ``a`` to selectable ``b``.

    For example::

        join_condition(tablea, tableb)

    yields an expression resembling::

        tablea.c.id == tableb.c.tablea_id

    The criteria are derived from the foreign key relationships linking
    the two selectables.  An error is raised when there is no way to join
    them, or when there is more than one way.

    :param a_subset: optional expression that is a sub-component of ``a``.
     A join against just this sub-component is attempted first, before the
     complete ``a`` construct is considered; when that succeeds the result
     is used even if ``a`` offers other ways to join.  This lets the
     "right side" of a join be passed in order to get a "natural join".

    :param consider_as_foreign_keys: optional set of columns, handed
     through unchanged to ``Join._join_condition()``.

    """
    options = {
        "a_subset": a_subset,
        "consider_as_foreign_keys": consider_as_foreign_keys,
    }
    # all of the actual work is done by the Join construct
    return Join._join_condition(a, b, **options)
def find_join_source(
    clauses: List[FromClause], join_to: FromClause
) -> List[int]:
    """Return the indexes of the FROM clauses that ``join_to`` can join from.

    Every entry in ``clauses`` is tested against the FROM objects that make
    up ``join_to``.  The index of each clause that is derived from at least
    one of them is collected, in the order the clauses were given.  Each
    index appears at most once, even when a clause is derived from several
    of the FROM objects.

    e.g.::

        clause1 = table1.join(table2)
        clause2 = table4.join(table5)

        join_to = table2.join(table3)

        # clause1 contains table2, so its index is returned
        find_join_source([clause1, clause2], join_to) == [0]

    :param clauses: list of FROM clauses to search.
    :param join_to: the selectable that is to be joined.
    :return: list of integer indexes into ``clauses``; empty when no clause
     matches.

    """

    selectables = list(_from_objects(join_to))
    return [
        i
        for i, f in enumerate(clauses)
        # any() stops at the first hit, so an index is never repeated
        if any(f.is_derived_from(s) for s in selectables)
    ]
def find_left_clause_that_matches_given(
    clauses: Sequence[FromClause], join_from: FromClause
) -> List[int]:
    """Locate the entries of ``clauses`` that correspond to ``join_from``.

    Returns the list of indexes into ``clauses`` whose element is derived
    from one of the FROM objects of ``join_from``.  When that "liberal"
    test selects more than one clause, a second and stricter test based on
    shared surface selectables is applied; its result is used if it
    selects anything at all.

    """

    targets = list(_from_objects(join_from))

    # liberal pass: is the clause derived from one of the targets?
    # this covers joins containing a table, or an aliased table or select
    # statement matching to a table, and will match a table to a
    # selectable that is adapted from that table.  With Query, this suits
    # the case where a join is being made to an adapted entity
    liberal = [
        position
        for position, from_ in enumerate(clauses)
        if any(from_.is_derived_from(target) for target in targets)
    ]
    if len(liberal) <= 1:
        return liberal

    # in an extremely small set of use cases, the target table is
    # represented in more than one FROM clause, such as embedded or
    # similar.  run another pass looking for a more exact match which does
    # not take adaption relationships into account.
    conservative = [
        position
        for position in liberal
        if any(
            set(surface_selectables(clauses[position])).intersection(
                surface_selectables(target)
            )
            for target in targets
        )
    ]
    return conservative or liberal
def find_left_clause_to_join_from(
    clauses: Sequence[FromClause],
    join_to: _JoinTargetElement,
    onclause: Optional[ColumnElement[Any]],
) -> List[int]:
    """Pick the positions in ``clauses`` that ``join_to`` may be joined from.

    Given a sequence of FROM clauses, a join target and an optional ON
    clause, a list of integer indexes into ``clauses`` is returned.

    When an ON clause is present, at least one clause can definitely be
    joined from.  With a single clause and an ON clause, that clause's
    index is returned.  With several clauses and an ON clause, the columns
    mentioned in the ON clause are used to work out which of the clauses
    are involved.

    """
    targets = set(_from_objects(join_to))

    # the ON clause is only used to disambiguate when there is more than
    # one clause to choose from.  otherwise don't try to limit; after all,
    # "ON TRUE" is a valid on clause
    onclause_cols = (
        _find_columns(onclause)
        if len(clauses) > 1 and onclause is not None
        else None
    )

    matched = []
    for position, candidate in enumerate(clauses):
        for target in targets.difference([candidate]):
            if onclause_cols is not None:
                # both sides together must supply every ON clause column
                joinable = (
                    set(candidate.c)
                    .union(target.c)
                    .issuperset(onclause_cols)
                )
            else:
                joinable = onclause is not None or Join._can_join(
                    candidate, target
                )
            if joinable:
                matched.append(position)
                break

    if len(matched) > 1:
        # this is the same "hide froms" logic from
        # Selectable._get_display_froms
        hidden = set(
            chain.from_iterable(
                _expand_cloned(from_._hide_froms) for from_ in clauses
            )
        )
        matched = [
            position for position in matched if clauses[position] not in hidden
        ]

    # an onclause was given but nothing resolved; assume that every
    # clause is a candidate
    if onclause is not None and not matched:
        return list(range(len(clauses)))
    return matched
def visit_binary_product(
    fn: Callable[
        [BinaryExpression[Any], ColumnElement[Any], ColumnElement[Any]], None
    ],
    expr: ColumnElement[Any],
) -> None:
    """Produce a traversal of the given expression, delivering
    column comparisons to the given function.

    The function is of the form::

        def my_fn(binary, left, right): ...

    For each binary expression located which has a
    comparison operator, the product of "left" and
    "right" will be delivered to that function,
    in terms of that binary.

    Hence an expression like::

        and_((a + b) == q + func.sum(e + f), j == r)

    would have the traversal:

    .. sourcecode:: text

        a <eq> q
        a <eq> e
        a <eq> f
        b <eq> q
        b <eq> e
        b <eq> f
        j <eq> r

    That is, every combination of "left" and
    "right" that doesn't further contain
    a binary comparison is passed as pairs.

    :param fn: callable invoked as ``fn(binary, left, right)`` for every
     pair; its return value is ignored.
    :param expr: the expression to traverse.

    """
    # comparison expressions currently being expanded, innermost last
    stack: List[BinaryExpression[Any]] = []

    def visit(element: ClauseElement) -> Iterator[ColumnElement[Any]]:
        if isinstance(element, ScalarSelect):
            # we don't want to dig into correlated subqueries,
            # those are just column elements by themselves
            yield element
        elif element.__visit_name__ == "binary" and operators.is_comparison(
            element.operator  # type: ignore
        ):
            stack.append(element)  # type: ignore
            for left in visit(element.left):  # type: ignore
                for right in visit(element.right):  # type: ignore
                    fn(stack[-1], left, right)
            stack.pop()
            # a comparison yields nothing to its parent.  Comparisons
            # nested below it are reached through the two loops above, so
            # no additional pass over get_children() is needed here; the
            # former pass only created generators that were never run.
        else:
            if isinstance(element, ColumnClause):
                yield element
            for elem in element.get_children():
                yield from visit(elem)

    # exhaust the generator so that every comparison is delivered
    list(visit(expr))
    visit = None  # type: ignore # remove gc cycles
def find_tables(
    clause: ClauseElement,
    *,
    check_columns: bool = False,
    include_aliases: bool = False,
    include_joins: bool = False,
    include_selects: bool = False,
    include_crud: bool = False,
) -> List[TableClause]:
    """Collect the Table objects found while traversing ``clause``.

    Elements visited as ``"table"`` are always collected.  The keyword
    flags add further visit names:

    :param check_columns: for ``"column"`` elements, collect the column's
     ``.table``.
    :param include_aliases: also collect ``"alias"``, ``"subquery"``,
     ``"tablesample"`` and ``"lateral"`` elements.
    :param include_joins: also collect ``"join"`` elements.
    :param include_selects: also collect ``"select"`` and
     ``"compound_select"`` elements.
    :param include_crud: for ``"insert"``, ``"update"`` and ``"delete"``
     elements, collect the statement's ``.table``.
    :return: list of collected objects, in traversal order.

    """

    found: List[TableClause] = []
    collect = found.append

    def collect_statement_table(stmt):
        found.append(stmt.table)

    def collect_column_table(column):
        found.append(column.table)

    handlers: Dict[str, _TraverseCallableType[Any]] = {}

    if include_selects:
        for visit_name in ("select", "compound_select"):
            handlers[visit_name] = collect

    if include_joins:
        handlers["join"] = collect

    if include_aliases:
        for visit_name in ("lateral", "tablesample", "subquery", "alias"):
            handlers[visit_name] = collect

    if include_crud:
        for visit_name in ("delete", "update", "insert"):
            handlers[visit_name] = collect_statement_table

    if check_columns:
        handlers["column"] = collect_column_table

    handlers["table"] = collect

    visitors.traverse(clause, {}, handlers)
    return found
def unwrap_order_by(clause: Any) -> Any:
    """Break up an 'order by' expression into individual column-expressions,
    without DESC/ASC/NULLS FIRST/NULLS LAST

    :param clause: the ORDER BY expression to take apart.
    :return: list of the distinct expressions located, in the order in
     which they were first encountered.

    """

    # "cols" is the membership test used to skip repeats; "result" keeps
    # the order of first appearance
    cols = util.column_set()
    result = []
    # work queue, consumed from the left
    stack = deque([clause])

    # examples
    # column -> ASC/DESC == column
    # column -> ASC/DESC -> label == column
    # column -> label -> ASC/DESC -> label == column
    # scalar_select -> label -> ASC/DESC == scalar_select -> label

    while stack:
        t = stack.popleft()
        # a column expression that is not itself an ordering modifier
        # (ordering modifiers are handled by the "else" below)
        if isinstance(t, ColumnElement) and (
            not isinstance(t, UnaryExpression)
            or not operators.is_ordering_modifier(t.modifier)  # type: ignore
        ):
            if isinstance(t, Label) and not isinstance(
                t.element, ScalarSelect
            ):
                # unwrap the label, unless it labels a scalar select,
                # which is kept together with its label
                t = t.element

                if isinstance(t, Grouping):
                    t = t.element

                # re-queue the unwrapped element so that it is examined
                # again from the top
                stack.append(t)
                continue
            elif isinstance(t, _label_reference):
                t = t.element

                stack.append(t)
                continue
            # textual label references are dropped from the result
            if isinstance(t, (_textual_label_reference)):
                continue
            if t not in cols:
                cols.add(t)
                result.append(t)

        else:
            # not a plain column expression: queue up the children
            for c in t.get_children():
                stack.append(c)
    return result
def unwrap_label_reference(element):
    """Return ``element`` with each ``_label_reference`` inside it replaced
    by the element that it wraps.

    A ``_textual_label_reference`` can't be unwrapped; encountering one
    trips an assertion.

    """

    def _unwrap(
        target: ExternallyTraversible, **kw: Any
    ) -> Optional[ExternallyTraversible]:
        if isinstance(target, _label_reference):
            return target.element
        if isinstance(target, _textual_label_reference):
            assert False, "can't unwrap a textual label reference"
        # None indicates that no replacement is supplied for this element
        return None

    return visitors.replacement_traverse(element, {}, _unwrap)
def expand_column_list_from_order_by(collist, order_by):
    """Work out which ORDER BY expressions are absent from a columns clause.

    Given the columns clause and the ORDER BY of a selectable, returns the
    list of column expressions from the ORDER BY which could be appended to
    ``collist``, leaving out those that ``collist`` already contains.

    """
    present = set()
    for col in collist:
        # a column with an _order_by_label_element is compared by way of
        # the element that it wraps
        if col._order_by_label_element is not None:
            present.add(col.element)
        else:
            present.add(col)

    candidates = [
        expr for ordering in order_by for expr in unwrap_order_by(ordering)
    ]

    return [expr for expr in candidates if expr not in present]
def clause_is_present(clause, search):
    """Report whether ``clause`` appears plainly within ``search``.

    Returns True when the target clause is one of the surface selectables
    of ``search``, that is, it is present without any subqueries or
    aliases being involved; False otherwise.

    Basically descends through Joins.

    """

    # use == here (not "is") so that Annotated's compare
    return any(clause == elem for elem in surface_selectables(search))
def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]:
    """Yield the FROM elements that make up ``clause``, leftmost first.

    Joins are expanded into their left and right sides and FromGrouping
    wrappers into the element they contain; anything else is yielded as is.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        if isinstance(current, Join):
            # the right side is pushed first so that the left side is
            # handled first
            pending.append(current.right)
            pending.append(current.left)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
        else:
            yield current
def surface_selectables(clause):
    """Yield ``clause`` along with every element reachable from it by
    descending through Join and FromGrouping constructs.

    The Join and FromGrouping constructs themselves are yielded as well.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        yield current
        if isinstance(current, Join):
            pending += [current.left, current.right]
        elif isinstance(current, FromGrouping):
            pending += [current.element]
def surface_selectables_only(clause: ClauseElement) -> Iterator[ClauseElement]:
    """Yield the selectables found at the surface of ``clause``.

    Join and FromGrouping constructs are descended into and are not
    yielded themselves.  A ColumnClause is replaced by its ``.table`` when
    it has one, and is yielded itself otherwise.  Every other element that
    is not None is yielded as is.

    Each element is yielded once per visit.  Previously a TableClause or
    Alias was yielded twice in a row, since it matched both the first test
    and the final catch-all branch.

    """
    stack = [clause]
    while stack:
        elem = stack.pop()
        if isinstance(elem, (TableClause, Alias)):
            yield elem
        elif isinstance(elem, Join):
            stack.extend((elem.left, elem.right))
        elif isinstance(elem, FromGrouping):
            stack.append(elem.element)
        elif isinstance(elem, ColumnClause):
            if elem.table is not None:
                # continue with the table that the column belongs to
                stack.append(elem.table)
            else:
                yield elem
        elif elem is not None:
            yield elem
def extract_first_column_annotation(column, annotation_name):
    """Return the value of the first ``annotation_name`` annotation found
    in ``column`` or in the elements beneath it, or None if there is none.

    The search is breadth-first and does not descend into FromGrouping or
    SelectBase children.

    """
    not_traversed = (FromGrouping, SelectBase)

    queue = deque([column])
    while queue:
        current = queue.popleft()
        if annotation_name in current._annotations:
            return current._annotations[annotation_name]
        queue.extend(
            child
            for child in current.get_children()
            if not isinstance(child, not_traversed)
        )
    return None
def selectables_overlap(left: FromClause, right: FromClause) -> bool:
    """Return True if ``left`` and ``right`` have at least one surface
    selectable in common, False otherwise."""

    left_surface = set(surface_selectables(left))
    right_surface = set(surface_selectables(right))
    return bool(left_surface & right_surface)
def bind_values(clause):
    """Collect the "bound" values present in ``clause``, in the order in
    which they are traversed.

    E.g.::

        >>> expr = and_(table.c.foo == 5, table.c.foo == 7)
        >>> bind_values(expr)
        [5, 7]
    """

    values = []
    # the effective_value of every bindparam that is visited is recorded
    visitors.traverse(
        clause,
        {},
        {"bindparam": lambda bind: values.append(bind.effective_value)},
    )
    return values
543def _quote_ddl_expr(element):
544 if isinstance(element, str):
545 element = element.replace("'", "''")
546 return "'%s'" % element
547 else:
548 return repr(element)
551class _repr_base:
552 _LIST: int = 0
553 _TUPLE: int = 1
554 _DICT: int = 2
556 __slots__ = ("max_chars",)
558 max_chars: int
560 def trunc(self, value: Any) -> str:
561 rep = repr(value)
562 lenrep = len(rep)
563 if lenrep > self.max_chars:
564 segment_length = self.max_chars // 2
565 rep = (
566 rep[0:segment_length]
567 + (
568 " ... (%d characters truncated) ... "
569 % (lenrep - self.max_chars)
570 )
571 + rep[-segment_length:]
572 )
573 return rep
def _repr_single_value(value):
    """Return ``repr(value)``, truncated once it exceeds 300 characters."""
    formatter = _repr_base()
    formatter.max_chars = 300
    return formatter.trunc(value)
class _repr_row(_repr_base):
    """String view of a row, rendered in the form of a tuple.

    Each value in the row is passed through ``trunc()``, so long values
    are shortened according to ``max_chars``.

    """

    __slots__ = ("row",)

    def __init__(self, row: Row[Any], max_chars: int = 300):
        self.max_chars = max_chars
        self.row = row

    def __repr__(self) -> str:
        rendered = ", ".join(map(self.trunc, self.row))
        # a row of one value gets the trailing comma of a one-tuple
        trailing = "," if len(self.row) == 1 else ""
        return f"({rendered}{trailing})"
599class _long_statement(str):
600 def __str__(self) -> str:
601 lself = len(self)
602 if lself > 500:
603 lleft = 250
604 lright = 100
605 trunc = lself - lleft - lright
606 return (
607 f"{self[0:lleft]} ... {trunc} "
608 f"characters truncated ... {self[-lright:]}"
609 )
610 else:
611 return str.__str__(self)
class _repr_params(_repr_base):
    """Provide a string view of bound parameters.

    Truncates display to a given number of 'multi' parameter sets,
    as well as long values to a given number of characters.

    """

    __slots__ = "params", "batches", "ismulti", "max_params"

    def __init__(
        self,
        params: Optional[_AnyExecuteParams],
        batches: int,
        max_params: int = 100,
        max_chars: int = 300,
        ismulti: Optional[bool] = None,
    ):
        """Construct the view.

        :param params: the parameters to render.
        :param batches: number of parameter sets displayed when ``params``
         is a "multi" collection holding more sets than this.
        :param max_params: largest number of parameters of a single set
         that is rendered without truncation.
        :param max_chars: longest ``repr()`` of a single value that is
         rendered without truncation.
        :param ismulti: True if ``params`` is a collection of parameter
         sets, False if it is a single set.  With None, ``params`` is
         rendered with ``trunc()`` as one opaque value.

        """
        self.params = params
        self.ismulti = ismulti
        self.batches = batches
        self.max_chars = max_chars
        self.max_params = max_params

    def __repr__(self) -> str:
        if self.ismulti is None:
            return self.trunc(self.params)

        if isinstance(self.params, list):
            typ = self._LIST

        elif isinstance(self.params, tuple):
            typ = self._TUPLE
        elif isinstance(self.params, dict):
            typ = self._DICT
        else:
            # not a container we know how to take apart
            return self.trunc(self.params)

        if self.ismulti:
            multi_params = cast(
                "_AnyMultiExecuteParams",
                self.params,
            )

            if len(self.params) > self.batches:
                msg = (
                    " ... displaying %i of %i total bound parameter sets ... "
                )
                # render the leading sets and the last two sets separately,
                # dropping the closing bracket of the former and the
                # opening bracket of the latter so they read as one
                # collection with the message in between
                return " ".join(
                    (
                        self._repr_multi(
                            multi_params[: self.batches - 2],
                            typ,
                        )[0:-1],
                        msg % (self.batches, len(self.params)),
                        self._repr_multi(multi_params[-2:], typ)[1:],
                    )
                )
            else:
                return self._repr_multi(multi_params, typ)
        else:
            return self._repr_params(
                cast(
                    "_AnySingleExecuteParams",
                    self.params,
                ),
                typ,
            )

    def _repr_multi(
        self,
        multi_params: _AnyMultiExecuteParams,
        typ: int,
    ) -> str:
        """Render a collection of parameter sets.

        The kind of the individual sets is taken from the first one.
        ``typ`` is the kind of the outer collection: square brackets are
        used for ``_LIST``, parentheses otherwise.

        """
        if multi_params:
            if isinstance(multi_params[0], list):
                elem_type = self._LIST
            elif isinstance(multi_params[0], tuple):
                elem_type = self._TUPLE
            elif isinstance(multi_params[0], dict):
                elem_type = self._DICT
            else:
                assert False, "Unknown parameter type %s" % (
                    type(multi_params[0])
                )

            elements = ", ".join(
                self._repr_params(params, elem_type) for params in multi_params
            )
        else:
            elements = ""

        if typ == self._LIST:
            return "[%s]" % elements
        else:
            return "(%s)" % elements

    def _get_batches(self, params: Iterable[Any]) -> Any:
        """Split ``params`` into the portions that are to be displayed.

        :return: a three-tuple ``(first, second, truncated)``.  If the
         number of parameters does not exceed ``max_params``, ``first``
         holds all of them and the other two members are None.  Otherwise
         ``first`` and ``second`` hold ``max_params // 2`` parameters from
         the beginning and from the end, and ``truncated`` is the number
         of parameters in between that are left out.

        """
        lparams = list(params)
        lenparams = len(lparams)
        if lenparams > self.max_params:
            lleft = max(self.max_params, 0) // 2
            # the tail uses an explicit start offset, as lparams[-0:]
            # would be the entire list when lleft is zero.  the count is
            # the number really omitted, also when max_params is odd.
            return (
                lparams[:lleft],
                lparams[lenparams - lleft :],
                lenparams - 2 * lleft,
            )
        else:
            return lparams, None, None

    def _repr_params(
        self,
        params: _AnySingleExecuteParams,
        typ: int,
    ) -> str:
        """Render a single parameter set according to its kind ``typ``."""
        # the kind codes are plain integers; compare by value
        if typ == self._DICT:
            return self._repr_param_dict(
                cast("_CoreSingleExecuteParams", params)
            )
        elif typ == self._TUPLE:
            return self._repr_param_tuple(cast("Sequence[Any]", params))
        else:
            return self._repr_param_list(params)

    def _repr_param_dict(self, params: _CoreSingleExecuteParams) -> str:
        """Render a dictionary of parameters, truncating as needed."""
        trunc = self.trunc
        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params.items())

        # trunclen is None exactly when nothing was truncated
        if trunclen is not None:
            text = "{%s" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_first_batch
                )
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s}" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_second_batch
                )
            )
        else:
            text = "{%s}" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_first_batch
                )
            )
        return text

    def _repr_param_tuple(self, params: Sequence[Any]) -> str:
        """Render a tuple of parameters, truncating as needed."""
        trunc = self.trunc

        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params)

        # trunclen is None exactly when nothing was truncated
        if trunclen is not None:
            text = "(%s" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s)" % (
                ", ".join(trunc(value) for value in items_second_batch),
            )
        else:
            # a single parameter gets the trailing comma of a one-tuple
            text = "(%s%s)" % (
                ", ".join(trunc(value) for value in items_first_batch),
                "," if len(items_first_batch) == 1 else "",
            )
        return text

    def _repr_param_list(self, params: _AnySingleExecuteParams) -> str:
        """Render a list of parameters, truncating as needed."""
        trunc = self.trunc
        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params)

        # trunclen is None exactly when nothing was truncated
        if trunclen is not None:
            text = "[%s" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s]" % (
                ", ".join(trunc(value) for value in items_second_batch)
            )
        else:
            text = "[%s]" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
        return text
def adapt_criterion_to_null(crit: _CE, nulls: Collection[Any]) -> _CE:
    """Return a copy of ``crit`` in which comparisons against selected bind
    parameters are converted to IS NULL.

    A binary expression is rewritten when one of its sides is a
    BindParameter whose ``_identifying_key`` is present in ``nulls``.

    """

    def refers_to_null(side):
        return (
            isinstance(side, BindParameter) and side._identifying_key in nulls
        )

    def visit_binary(binary):
        if refers_to_null(binary.left):
            # the NULL is on the left side; reverse the order so that the
            # other operand ends up on the left
            binary.left = binary.right
        elif not refers_to_null(binary.right):
            return

        binary.right = Null()
        binary.operator = operators.is_
        binary.negate = operators.is_not

    return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
def splice_joins(
    left: Optional[FromClause],
    right: Optional[FromClause],
    stop_on: Optional[FromClause] = None,
) -> Optional[FromClause]:
    """Adapt ``right`` so that it refers to ``left``.

    The chain of Join objects reachable through the ``.left`` attribute of
    ``right`` is cloned, with each ON clause passed through a
    ClauseAdapter built on ``left``.  The element at the end of the chain
    is passed through the same adapter.

    :param left: the selectable to adapt towards; when None, ``right`` is
     returned unchanged.
    :param right: the join or selectable to be adapted.
    :param stop_on: optional Join at which the descent stops; it is passed
     to the adapter as a whole rather than being cloned.
    :return: the adapted form of ``right``.

    """
    if left is None:
        return right

    # each entry pairs an element with the cloned Join whose .left it is
    stack: List[Tuple[Optional[FromClause], Optional[Join]]] = [(right, None)]

    adapter = ClauseAdapter(left)
    ret = None
    while stack:
        (right, prevright) = stack.pop()
        if isinstance(right, Join) and right is not stop_on:
            # work on a clone; the Join that was passed in isn't modified
            right = right._clone()
            right.onclause = adapter.traverse(right.onclause)
            stack.append((right.left, right))
        else:
            right = adapter.traverse(right)
        if prevright is not None:
            assert right is not None
            # attach the processed element to the clone one level up
            prevright.left = right
        if ret is None:
            # the first element processed is the outermost one
            ret = right

    return ret
@overload
def reduce_columns(
    columns: Iterable[ColumnElement[Any]],
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Sequence[ColumnElement[Any]]: ...


@overload
def reduce_columns(
    columns: _SelectIterable,
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Sequence[Union[ColumnElement[Any], TextClause]]: ...


def reduce_columns(
    columns: _SelectIterable,
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Collection[Union[ColumnElement[Any], TextClause]]:
    r"""Reduce a list of columns to those having no "natural equivalent"
    elsewhere in the list.

    Two columns are "natural equivalents" when they ultimately represent
    the same value because a foreign key relates them.  The result is the
    smallest list of columns in which no column has such an equivalent
    remaining.

    \*clauses is an optional list of join clauses; these are traversed in
    order to identify additional columns that are "equivalent".

    \**kw may specify 'ignore_nonexistent_tables', which causes foreign
    keys whose tables are not yet configured, or whose columns aren't yet
    present, to be ignored.  It may also specify 'only_synonyms', in which
    case two columns are considered equivalent only if their names are the
    same.

    The primary use of this function is to determine the most minimal
    "primary key" of a selectable, by reducing the primary key columns
    present in the selectable to just those that are not repeated.

    """
    only_synonyms = kw.pop("only_synonyms", False)
    ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)

    column_set = util.OrderedSet(columns)
    # text clauses take no part in the reduction
    cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference(
        c for c in column_set if is_text_clause(c)  # type: ignore
    )

    # TODO: add specific coverage for both of these exceptions
    # to test/sql/test_selectable ReduceTest
    unresolvable = (exc.NoReferencedColumnError, exc.NoReferencedTableError)

    omit = util.column_set()
    for col in cset_no_text:
        foreign_keys = [
            fk for proxied in col.proxy_set for fk in proxied.foreign_keys
        ]
        for fk in foreign_keys:
            for c in cset_no_text:
                if c is col:
                    continue
                try:
                    fk_col = fk.column
                except unresolvable:
                    if ignore_nonexistent_tables:
                        continue
                    raise
                if fk_col.shares_lineage(c) and (
                    not only_synonyms or c.name == col.name
                ):
                    omit.add(col)
                    break

    def visit_binary(binary):
        if not (binary.operator == operators.eq):
            return

        # columns still under consideration, along with their proxies
        cols = util.column_set(
            chain(*[c.proxy_set for c in cset_no_text.difference(omit)])
        )
        if binary.left in cols and binary.right in cols:
            for c in reversed(cset_no_text):
                if c.shares_lineage(binary.right) and (
                    not only_synonyms or c.name == binary.left.name
                ):
                    omit.add(c)
                    break

    for clause in clauses:
        if clause is not None:
            visitors.traverse(clause, {}, {"binary": visit_binary})

    return column_set.difference(omit)
def criterion_as_pairs(
    expression,
    consider_as_foreign_keys=None,
    consider_as_referenced_keys=None,
    any_operator=False,
):
    """Traverse ``expression`` and return the binary criterion pairs in it.

    Binary expressions between two column elements are examined; unless
    ``any_operator`` is set, only those using the ``==`` operator.

    :param consider_as_foreign_keys: optional collection of columns to be
     treated as the foreign key side of a comparison; a column from it is
     placed second in its pair.
    :param consider_as_referenced_keys: optional collection of columns to
     be treated as the referenced side of a comparison; a column from it
     is placed first in its pair.  Can't be combined with
     ``consider_as_foreign_keys``.
    :param any_operator: when True, operators other than ``==`` are
     accepted as well.
    :return: list of two-tuples of column elements.  When neither
     collection is given, only comparisons between two ``Column`` objects
     are considered, and the column for which ``references()`` reports True
     is placed second.

    """

    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError(
            "Can only specify one of "
            "'consider_as_foreign_keys' or "
            "'consider_as_referenced_keys'"
        )

    pairs: List[Tuple[ColumnElement[Any], ColumnElement[Any]]] = []

    def same_column(a, b):
        # structural comparison, as opposed to "a is b"
        return a.compare(b)

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return

        left, right = binary.left, binary.right
        if not (
            isinstance(left, ColumnElement)
            and isinstance(right, ColumnElement)
        ):
            return

        if consider_as_foreign_keys:
            fks = consider_as_foreign_keys
            if left in fks and (same_column(right, left) or right not in fks):
                pairs.append((right, left))
            elif right in fks and (
                same_column(left, right) or left not in fks
            ):
                pairs.append((left, right))
        elif consider_as_referenced_keys:
            refs = consider_as_referenced_keys
            if left in refs and (
                same_column(right, left) or right not in refs
            ):
                pairs.append((left, right))
            elif right in refs and (
                same_column(left, right) or left not in refs
            ):
                pairs.append((right, left))
        elif isinstance(left, Column) and isinstance(right, Column):
            if left.references(right):
                pairs.append((right, left))
            elif right.references(left):
                pairs.append((left, right))

    visitors.traverse(expression, {}, {"binary": visit_binary})
    return pairs
1035class ClauseAdapter(visitors.ReplacingExternalTraversal):
1036 """Clones and modifies clauses based on column correspondence.
1038 E.g.::
1040 table1 = Table(
1041 "sometable",
1042 metadata,
1043 Column("col1", Integer),
1044 Column("col2", Integer),
1045 )
1046 table2 = Table(
1047 "someothertable",
1048 metadata,
1049 Column("col1", Integer),
1050 Column("col2", Integer),
1051 )
1053 condition = table1.c.col1 == table2.c.col1
1055 make an alias of table1::
1057 s = table1.alias("foo")
1059 calling ``ClauseAdapter(s).traverse(condition)`` converts
1060 condition to read::
1062 s.c.col1 == table2.c.col1
1064 """
1066 __slots__ = (
1067 "__traverse_options__",
1068 "selectable",
1069 "include_fn",
1070 "exclude_fn",
1071 "equivalents",
1072 "adapt_on_names",
1073 "adapt_from_selectables",
1074 )
1076 def __init__(
1077 self,
1078 selectable: Selectable,
1079 equivalents: Optional[_EquivalentColumnMap] = None,
1080 include_fn: Optional[Callable[[ClauseElement], bool]] = None,
1081 exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
1082 adapt_on_names: bool = False,
1083 anonymize_labels: bool = False,
1084 adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
1085 ):
1086 self.__traverse_options__ = {
1087 "stop_on": [selectable],
1088 "anonymize_labels": anonymize_labels,
1089 }
1090 self.selectable = selectable
1091 self.include_fn = include_fn
1092 self.exclude_fn = exclude_fn
1093 self.equivalents = util.column_dict(equivalents or {})
1094 self.adapt_on_names = adapt_on_names
1095 self.adapt_from_selectables = adapt_from_selectables
    # typing-only declarations: these stubs are seen by type checkers
    # only and define nothing at runtime
    if TYPE_CHECKING:

        # passing None produces None
        @overload
        def traverse(self, obj: Literal[None]) -> None: ...

        # note this specializes the ReplacingExternalTraversal.traverse()
        # method to state
        # that we will return the same kind of ExternalTraversal object as
        # we were given.  This is probably not 100% true, such as it's
        # possible for us to swap out Alias for Table at the top level.
        # Ideally there could be overloads specific to ColumnElement and
        # FromClause but Mypy is not accepting those as compatible with
        # the base ReplacingExternalTraversal
        @overload
        def traverse(self, obj: _ET) -> _ET: ...

        def traverse(
            self, obj: Optional[ExternallyTraversible]
        ) -> Optional[ExternallyTraversible]: ...
1117 def _corresponding_column(
1118 self, col, require_embedded, _seen=util.EMPTY_SET
1119 ):
1120 newcol = self.selectable.corresponding_column(
1121 col, require_embedded=require_embedded
1122 )
1123 if newcol is None and col in self.equivalents and col not in _seen:
1124 for equiv in self.equivalents[col]:
1125 newcol = self._corresponding_column(
1126 equiv,
1127 require_embedded=require_embedded,
1128 _seen=_seen.union([col]),
1129 )
1130 if newcol is not None:
1131 return newcol
1133 if (
1134 self.adapt_on_names
1135 and newcol is None
1136 and isinstance(col, NamedColumn)
1137 ):
1138 newcol = self.selectable.exported_columns.get(col.name)
1139 return newcol
@util.preload_module("sqlalchemy.sql.functions")
def replace(
    self, col: _ET, _include_singleton_constants: bool = False
) -> Optional[_ET]:
    """Return the replacement for ``col``, or ``None`` for no replacement.

    The decision is made in this order:

    * ``None`` if ``include_fn`` is set and rejects ``col``, or if
      ``exclude_fn`` is set and accepts it.
    * For a ``FromClause`` that is not a ``FunctionElement``:
      ``self.selectable`` if our selectable is derived from ``col`` (and,
      when ``adapt_from_selectables`` is set, at least one of those is
      derived from ``col`` as well); ``col`` itself if it is an ``Alias``
      of a ``TableClause``; otherwise ``None``.
    * ``None`` for anything else that is not a ``ColumnElement``.
    * ``None`` for singleton constants unless
      ``_include_singleton_constants`` is true.
    * Otherwise the result of ``self._corresponding_column()`` with
      ``require_embedded=True``, after substituting an ``"adapt_column"``
      annotation if present and applying the ``adapt_from_selectables``
      filter.

    :param col: the element being visited.
    :param _include_singleton_constants: when true, singleton constants
     are looked up like any other column rather than being skipped.
    """
    functions = util.preloaded.sql_functions

    # TODO: cython candidate

    if self.include_fn and not self.include_fn(col):  # type: ignore
        return None
    elif self.exclude_fn and self.exclude_fn(col):  # type: ignore
        return None

    if isinstance(col, FromClause) and not isinstance(
        col, functions.FunctionElement
    ):
        if self.selectable.is_derived_from(col):
            if self.adapt_from_selectables:
                # for/else: bail out with None only when no selectable in
                # adapt_from_selectables is derived from col
                for adp in self.adapt_from_selectables:
                    if adp.is_derived_from(col):
                        break
                else:
                    return None
            return self.selectable  # type: ignore
        elif isinstance(col, Alias) and isinstance(
            col.element, TableClause
        ):
            # we are a SELECT statement and not derived from an alias of a
            # table (which nonetheless may be a table our SELECT derives
            # from), so return the alias to prevent further traversal
            # or
            # we are an alias of a table and we are not derived from an
            # alias of a table (which nonetheless may be the same table
            # as ours) so, same thing
            return col
        else:
            # other cases where we are a selectable and the element
            # is another join or selectable that contains a table which our
            # selectable derives from, that we want to process
            return None

    elif not isinstance(col, ColumnElement):
        return None
    elif not _include_singleton_constants and col._is_singleton_constant:
        # don't swap out NULL, TRUE, FALSE for a label name
        # in a SQL statement that's being rewritten,
        # leave them as the constant.  This is first noted in #6259,
        # however the logic to check this moved here as of #7154 so that
        # it is made specific to SQL rewriting and not all column
        # correspondence

        return None

    # an "adapt_column" annotation names the column to look up in place
    # of the annotated element
    if "adapt_column" in col._annotations:
        col = col._annotations["adapt_column"]

    if TYPE_CHECKING:
        assert isinstance(col, KeyedColumnElement)

    if self.adapt_from_selectables and col not in self.equivalents:
        # for/else: give up unless some selectable in
        # adapt_from_selectables has a column corresponding to col
        for adp in self.adapt_from_selectables:
            if adp.c.corresponding_column(col, False) is not None:
                break
        else:
            return None

    if TYPE_CHECKING:
        assert isinstance(col, KeyedColumnElement)

    return self._corresponding_column(  # type: ignore
        col, require_embedded=True
    )
class _ColumnLookup(Protocol):
    """Typing protocol for a ``[]``-style lookup of adapted elements.

    Used as the declared type of ``ColumnAdapter.columns``.  The overloads
    state that ``None`` maps to ``None`` and that any other key yields an
    object of the same static type as the key.
    """

    @overload
    def __getitem__(self, key: None) -> None: ...

    @overload
    def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: ...

    @overload
    def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: ...

    @overload
    def __getitem__(self, key: _ET) -> _ET: ...

    def __getitem__(self, key: Any) -> Any: ...
class ColumnAdapter(ClauseAdapter):
    """Extends ClauseAdapter with extra utility functions.

    Key aspects of ColumnAdapter include:

    * Expressions that are adapted are stored in a persistent
      .columns collection; so that an expression E adapted into
      an expression E1, will return the same object E1 when adapted
      a second time.  This is important in particular for things like
      Label objects that are anonymized, so that the ColumnAdapter can
      be used to present a consistent "adapted" view of things.

    * Exclusion of items from the persistent collection based on
      include/exclude rules, but also independent of hash identity.
      This is because "annotated" items all have the same hash identity
      as their parent.

    * "wrapping" capability is added, so that the replacement of an
      expression E can proceed through a series of adapters.  This differs
      from the visitor's "chaining" feature in that the resulting object is
      passed through all replacing functions unconditionally, rather than
      stopping at the first one that returns non-None.

    * An adapt_required option, used by eager loading to indicate that
      we don't trust a result row column that is not translated.
      This is to prevent a column from being interpreted as that
      of the child row in a self-referential scenario, see
      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency

    """

    __slots__ = (
        "columns",
        "adapt_required",
        "allow_label_resolve",
        "_wrap",
        "__weakref__",
    )

    # lookup of adapted elements, keyed by the original element
    columns: _ColumnLookup

    def __init__(
        self,
        selectable: Selectable,
        equivalents: Optional[_EquivalentColumnMap] = None,
        adapt_required: bool = False,
        include_fn: Optional[Callable[[ClauseElement], bool]] = None,
        exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
        adapt_on_names: bool = False,
        allow_label_resolve: bool = True,
        anonymize_labels: bool = False,
        adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
    ):
        """Construct the adapter.

        ``selectable``, ``equivalents``, ``include_fn``, ``exclude_fn``,
        ``adapt_on_names``, ``anonymize_labels`` and
        ``adapt_from_selectables`` are passed through to ``ClauseAdapter``.

        :param adapt_required: when true, ``_locate_col()`` returns
         ``None`` for a column that was not changed by adaptation.
        :param allow_label_resolve: value assigned to
         ``_allow_label_resolve`` on each newly adapted column.
        """
        super().__init__(
            selectable,
            equivalents,
            include_fn=include_fn,
            exclude_fn=exclude_fn,
            adapt_on_names=adapt_on_names,
            anonymize_labels=anonymize_labels,
            adapt_from_selectables=adapt_from_selectables,
        )

        # NOTE(review): WeakPopulateDict is presumably a weak-referencing
        # cache that calls _locate_col for a missing key -- confirm
        # against its definition in the util package.
        self.columns = util.WeakPopulateDict(self._locate_col)  # type: ignore
        if self.include_fn or self.exclude_fn:
            # apply the include/exclude rules in front of the cache
            self.columns = self._IncludeExcludeMapping(self, self.columns)
        self.adapt_required = adapt_required
        self.allow_label_resolve = allow_label_resolve
        self._wrap = None

    class _IncludeExcludeMapping:
        """Front end for ``columns`` that applies include/exclude rules.

        The rules are evaluated on every lookup, before the underlying
        collection is consulted.
        """

        def __init__(self, parent, columns):
            self.parent = parent
            self.columns = columns

        def __getitem__(self, key):
            """Look up ``key``, bypassing the parent for excluded keys.

            A key rejected by ``include_fn`` or accepted by ``exclude_fn``
            is looked up in the wrapped adapter's ``columns`` if the
            parent has one, and is otherwise returned unchanged.
            """
            if (
                self.parent.include_fn and not self.parent.include_fn(key)
            ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)):
                if self.parent._wrap:
                    return self.parent._wrap.columns[key]
                else:
                    return key
            return self.columns[key]

    def wrap(self, adapter):
        """Return a copy of this adapter that also applies ``adapter``.

        The copy is made with ``copy.copy()``, gets ``adapter`` as its
        ``_wrap`` and is given a fresh ``columns`` collection of its own;
        this adapter is left unchanged.
        """
        ac = copy.copy(self)
        ac._wrap = adapter
        ac.columns = util.WeakPopulateDict(ac._locate_col)  # type: ignore
        if ac.include_fn or ac.exclude_fn:
            ac.columns = self._IncludeExcludeMapping(ac, ac.columns)

        return ac

    @overload
    def traverse(self, obj: Literal[None]) -> None: ...

    @overload
    def traverse(self, obj: _ET) -> _ET: ...

    def traverse(
        self, obj: Optional[ExternallyTraversible]
    ) -> Optional[ExternallyTraversible]:
        """Return the adapted form of ``obj`` via the ``columns`` lookup."""
        return self.columns[obj]

    def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
        """Chain ``visitor``, which must itself be a ``ColumnAdapter``."""
        assert isinstance(visitor, ColumnAdapter)

        return super().chain(visitor)

    if TYPE_CHECKING:
        # typing-only: narrows the element type of visitor_iterator

        @property
        def visitor_iterator(self) -> Iterator[ColumnAdapter]: ...

    # class-level aliases: adapt_clause is the traverse() defined above,
    # adapt_list is ClauseAdapter.copy_and_process
    adapt_clause = traverse
    adapt_list = ClauseAdapter.copy_and_process

    def adapt_check_present(
        self, col: ColumnElement[Any]
    ) -> Optional[ColumnElement[Any]]:
        """Adapt ``col``, returning ``None`` if it has no counterpart.

        ``None`` is returned when the ``columns`` lookup hands back ``col``
        itself and ``_corresponding_column(col, True)`` also finds
        nothing; otherwise the looked-up column is returned.
        """
        newcol = self.columns[col]

        if newcol is col and self._corresponding_column(col, True) is None:
            return None

        return newcol

    def _locate_col(
        self, col: ColumnElement[Any]
    ) -> Optional[ColumnElement[Any]]:
        """Compute the adapted form of ``col``.

        This is the callable handed to ``util.WeakPopulateDict`` when
        ``columns`` is built.  Returns ``None`` only when
        ``adapt_required`` is set and ``col`` came back unchanged.
        """
        # both replace and traverse() are overly complicated for what
        # we are doing here and we would do better to have an inlined
        # version that doesn't build up as much overhead.  the issue is that
        # sometimes the lookup does in fact have to adapt the insides of
        # say a labeled scalar subquery.   However, if the object is an
        # Immutable, i.e. Column objects, we can skip the "clone" /
        # "copy internals" part since those will be no-ops in any case.
        # additionally we want to catch singleton objects null/true/false
        # and make sure they are adapted as well here.

        if col._is_immutable:
            # take the first non-None replacement offered by any visitor
            # in the chain; fall back to col itself if there is none
            for vis in self.visitor_iterator:
                c = vis.replace(col, _include_singleton_constants=True)
                if c is not None:
                    break
            else:
                c = col
        else:
            c = ClauseAdapter.traverse(self, col)

        if self._wrap:
            # pass the result on through the wrapped adapter as well
            c2 = self._wrap._locate_col(c)
            if c2 is not None:
                c = c2

        if self.adapt_required and c is col:
            return None

        # allow_label_resolve is consumed by one case for joined eager loading
        # as part of its logic to prevent its own columns from being affected
        # by .order_by().  Before full typing was applied to the ORM, this
        # logic would set this attribute on the incoming object (which is
        # typically a column, but we have a test for it being a non-column
        # object) if no column were found.  While this seemed to
        # have no negative effects, this adjustment should only occur on the
        # new column which is assumed to be local to an adapted selectable.
        if c is not col:
            c._allow_label_resolve = self.allow_label_resolve

        return c
def _offset_or_limit_clause(
    element: _LimitOffsetType,
    name: Optional[str] = None,
    type_: Optional[_TypeEngineArgument[int]] = None,
) -> ColumnElement[int]:
    """Coerce ``element`` into an "offset or limit" clause.

    Plain integers are converted into an expression; a value that is
    already an expression is passed through.  The coercion itself is done
    by ``coercions.expect()`` under ``roles.LimitOffsetRole``.

    :param element: the integer or expression to coerce.
    :param name: forwarded to ``coercions.expect()`` as ``name``.
    :param type_: forwarded to ``coercions.expect()`` as ``type_``.
    """
    limit_offset_role = roles.LimitOffsetRole
    coerced = coercions.expect(
        limit_offset_role, element, name=name, type_=type_
    )
    return coerced
1420def _offset_or_limit_clause_asint_if_possible(
1421 clause: _LimitOffsetType,
1422) -> _LimitOffsetType:
1423 """Return the offset or limit clause as a simple integer if possible,
1424 else return the clause.
1426 """
1427 if clause is None:
1428 return None
1429 if hasattr(clause, "_limit_offset_value"):
1430 value = clause._limit_offset_value
1431 return util.asint(value)
1432 else:
1433 return clause
1436def _make_slice(
1437 limit_clause: _LimitOffsetType,
1438 offset_clause: _LimitOffsetType,
1439 start: int,
1440 stop: int,
1441) -> Tuple[Optional[ColumnElement[int]], Optional[ColumnElement[int]]]:
1442 """Compute LIMIT/OFFSET in terms of slice start/end"""
1444 # for calculated limit/offset, try to do the addition of
1445 # values to offset in Python, however if a SQL clause is present
1446 # then the addition has to be on the SQL side.
1448 # TODO: typing is finding a few gaps in here, see if they can be
1449 # closed up
1451 if start is not None and stop is not None:
1452 offset_clause = _offset_or_limit_clause_asint_if_possible(
1453 offset_clause
1454 )
1455 if offset_clause is None:
1456 offset_clause = 0
1458 if start != 0:
1459 offset_clause = offset_clause + start # type: ignore
1461 if offset_clause == 0:
1462 offset_clause = None
1463 else:
1464 assert offset_clause is not None
1465 offset_clause = _offset_or_limit_clause(offset_clause)
1467 limit_clause = _offset_or_limit_clause(stop - start)
1469 elif start is None and stop is not None:
1470 limit_clause = _offset_or_limit_clause(stop)
1471 elif start is not None and stop is None:
1472 offset_clause = _offset_or_limit_clause_asint_if_possible(
1473 offset_clause
1474 )
1475 if offset_clause is None:
1476 offset_clause = 0
1478 if start != 0:
1479 offset_clause = offset_clause + start
1481 if offset_clause == 0:
1482 offset_clause = None
1483 else:
1484 offset_clause = _offset_or_limit_clause(offset_clause)
1486 return limit_clause, offset_clause