1# sql/util.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""High level utilities which build upon other modules here.
10
11"""
12from __future__ import annotations
13
14from collections import deque
15import copy
16from itertools import chain
17import typing
18from typing import AbstractSet
19from typing import Any
20from typing import Callable
21from typing import cast
22from typing import Collection
23from typing import Dict
24from typing import Iterable
25from typing import Iterator
26from typing import List
27from typing import Optional
28from typing import overload
29from typing import Sequence
30from typing import Tuple
31from typing import TYPE_CHECKING
32from typing import TypeVar
33from typing import Union
34
35from . import coercions
36from . import operators
37from . import roles
38from . import visitors
39from ._typing import is_text_clause
40from .annotation import _deep_annotate as _deep_annotate # noqa: F401
41from .annotation import _deep_deannotate as _deep_deannotate # noqa: F401
42from .annotation import _shallow_annotate as _shallow_annotate # noqa: F401
43from .base import _expand_cloned
44from .base import _from_objects
45from .cache_key import HasCacheKey as HasCacheKey # noqa: F401
46from .ddl import sort_tables as sort_tables # noqa: F401
47from .elements import _find_columns as _find_columns
48from .elements import _label_reference
49from .elements import _textual_label_reference
50from .elements import BindParameter
51from .elements import ClauseElement
52from .elements import ColumnClause
53from .elements import ColumnElement
54from .elements import Grouping
55from .elements import KeyedColumnElement
56from .elements import Label
57from .elements import NamedColumn
58from .elements import Null
59from .elements import UnaryExpression
60from .schema import Column
61from .selectable import Alias
62from .selectable import FromClause
63from .selectable import FromGrouping
64from .selectable import Join
65from .selectable import ScalarSelect
66from .selectable import SelectBase
67from .selectable import TableClause
68from .visitors import _ET
69from .. import exc
70from .. import util
71from ..util.typing import Literal
72from ..util.typing import Protocol
73
74if typing.TYPE_CHECKING:
75 from ._typing import _EquivalentColumnMap
76 from ._typing import _LimitOffsetType
77 from ._typing import _TypeEngineArgument
78 from .elements import BinaryExpression
79 from .elements import TextClause
80 from .selectable import _JoinTargetElement
81 from .selectable import _SelectIterable
82 from .selectable import Selectable
83 from .visitors import _TraverseCallableType
84 from .visitors import ExternallyTraversible
85 from .visitors import ExternalTraversal
86 from ..engine.interfaces import _AnyExecuteParams
87 from ..engine.interfaces import _AnyMultiExecuteParams
88 from ..engine.interfaces import _AnySingleExecuteParams
89 from ..engine.interfaces import _CoreSingleExecuteParams
90 from ..engine.row import Row
91
92_CE = TypeVar("_CE", bound="ColumnElement[Any]")
93
94
def join_condition(
    a: FromClause,
    b: FromClause,
    a_subset: Optional[FromClause] = None,
    consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]] = None,
) -> ColumnElement[bool]:
    """Return the ON criteria that joins selectable ``a`` to selectable ``b``.

    For example::

        join_condition(tablea, tableb)

    yields an expression resembling::

        tablea.c.id == tableb.c.tablea_id

    The criteria is derived from the foreign key relationships between
    the two selectables; an error is raised when there is no way to join
    them, or when there is more than one way.

    This function is a module-level entry point which hands all of its
    arguments to ``Join._join_condition()`` unchanged.

    :param a: the left side selectable.

    :param b: the right side selectable.

    :param a_subset: An optional expression that is a sub-component
        of ``a``.  A join against just this sub-component is attempted
        before the full ``a`` construct is considered; when it succeeds it
        wins even if ``a`` as a whole offers other ways to join.  Passing
        the "right side" of an existing join here provides a
        "natural join".

    :param consider_as_foreign_keys: optional set of columns, passed
        through to ``Join._join_condition()``.

    """
    condition = Join._join_condition(
        a,
        b,
        a_subset=a_subset,
        consider_as_foreign_keys=consider_as_foreign_keys,
    )
    return condition
129
130
def find_join_source(
    clauses: List[FromClause], join_to: FromClause
) -> List[int]:
    """Given a list of FROM clauses and a selectable,
    return the indexes of those clauses which can be joined against
    the selectable.

    A clause qualifies when it is derived from at least one of the FROM
    objects of ``join_to``.  Each qualifying index is reported exactly
    once, in the order of ``clauses``; an empty list is returned when
    nothing matches.

    e.g.::

        clause1 = table1.join(table2)
        clause2 = table4.join(table5)

        join_to = table2.join(table3)

        # clause1 is at position 0
        find_join_source([clause1, clause2], join_to) == [0]

    :param clauses: candidate FROM clauses to search.

    :param join_to: the selectable that is the target of the join.

    :return: list of integer positions within ``clauses``.

    """

    selectables = list(_from_objects(join_to))
    # any() stops at the first matching selectable, so a clause that is
    # derived from several of them is still listed only once.
    return [
        i
        for i, f in enumerate(clauses)
        if any(f.is_derived_from(s) for s in selectables)
    ]
157
158
def find_left_clause_that_matches_given(
    clauses: Sequence[FromClause], join_from: FromClause
) -> List[int]:
    """Locate the entries of ``clauses`` that correspond to ``join_from``.

    :param clauses: sequence of FROM clauses to search.

    :param join_from: the selectable being joined from.

    :return: list of integer positions within ``clauses``.  When the
     liberal "is derived from" test matches more than one clause, a
     stricter test based on :func:`.surface_selectables` is applied and
     its result is preferred if it matches anything.

    """

    targets = list(_from_objects(join_from))

    # liberal pass: the clause is derived from one of the targets.
    # this can be joins containing a table, or an aliased table
    # or select statement matching to a table.  This check
    # will match a table to a selectable that is adapted from
    # that table.  With Query, this suits the case where a join
    # is being made to an adapted entity
    liberal_idx = [
        position
        for position, from_ in enumerate(clauses)
        if any(from_.is_derived_from(target) for target in targets)
    ]

    if len(liberal_idx) <= 1:
        return liberal_idx

    # in an extremely small set of use cases, a join is being made where
    # there are multiple FROM clauses where our target table is represented
    # in more than one, such as embedded or similar.   in this case, do
    # another pass where we try to get a more exact match where we aren't
    # looking at adaption relationships.
    conservative_idx = [
        position
        for position in liberal_idx
        if any(
            set(surface_selectables(clauses[position])).intersection(
                surface_selectables(target)
            )
            for target in targets
        )
    ]

    return conservative_idx if conservative_idx else liberal_idx
201
202
def find_left_clause_to_join_from(
    clauses: Sequence[FromClause],
    join_to: _JoinTargetElement,
    onclause: Optional[ColumnElement[Any]],
) -> List[int]:
    """Return the positions in ``clauses`` that ``join_to`` can be joined
    from.

    When an ``onclause`` is present, at least one clause can definitely
    be joined from.  With a single clause and an ``onclause``, that
    clause's index is returned.  With several clauses and an ``onclause``,
    the columns referenced by the ``onclause`` are used to work out which
    clauses apply.

    :param clauses: candidate left-hand FROM clauses.

    :param join_to: the target of the join.

    :param onclause: optional ON criteria for the join.

    :return: list of integer positions within ``clauses``.

    """
    targets = set(_from_objects(join_to))

    # if we are given more than one target clause to join
    # from, use the onclause to provide a more specific answer.
    # otherwise, don't try to limit, after all, "ON TRUE" is a valid
    # on clause
    resolve_ambiguity = len(clauses) > 1 and onclause is not None
    cols_in_onclause = _find_columns(onclause) if resolve_ambiguity else None

    def _is_candidate(from_: FromClause, target: FromClause) -> bool:
        # decide whether "from_" may serve as the left side for "target"
        if resolve_ambiguity:
            assert cols_in_onclause is not None
            combined = set(from_.c).union(target.c)
            return combined.issuperset(cols_in_onclause)
        return onclause is not None or Join._can_join(from_, target)

    matched = [
        position
        for position, from_ in enumerate(clauses)
        if any(
            _is_candidate(from_, target)
            for target in targets.difference([from_])
        )
    ]

    if len(matched) > 1:
        # this is the same "hide froms" logic from
        # Selectable._get_display_froms
        hidden = set(
            chain.from_iterable(
                [_expand_cloned(from_._hide_froms) for from_ in clauses]
            )
        )
        matched = [
            position for position in matched if clauses[position] not in hidden
        ]

    if matched or onclause is None:
        return matched

    # onclause was given and none of them resolved, so assume
    # all indexes can match
    return list(range(len(clauses)))
258
259
def visit_binary_product(
    fn: Callable[
        [BinaryExpression[Any], ColumnElement[Any], ColumnElement[Any]], None
    ],
    expr: ColumnElement[Any],
) -> None:
    """Produce a traversal of the given expression, delivering
    column comparisons to the given function.

    The function is of the form::

        def my_fn(binary, left, right): ...

    For each binary expression located which has a
    comparison operator, the product of "left" and
    "right" will be delivered to that function,
    in terms of that binary.

    Hence an expression like::

        and_((a + b) == q + func.sum(e + f), j == r)

    would have the traversal:

    .. sourcecode:: text

        a <eq> q
        a <eq> e
        a <eq> f
        b <eq> q
        b <eq> e
        b <eq> f
        j <eq> r

    That is, every combination of "left" and
    "right" that doesn't further contain
    a binary comparison is passed as pairs.

    :param fn: callable invoked as ``fn(binary, left, right)`` for each
     pair.

    :param expr: the expression to traverse.

    """
    # comparison expressions currently being expanded, innermost last
    stack: List[BinaryExpression[Any]] = []

    def visit(element: ClauseElement) -> Iterator[ColumnElement[Any]]:
        if isinstance(element, ScalarSelect):
            # we don't want to dig into correlated subqueries,
            # those are just column elements by themselves
            yield element
        elif element.__visit_name__ == "binary" and operators.is_comparison(
            element.operator  # type: ignore
        ):
            stack.append(element)  # type: ignore
            for l in visit(element.left):  # type: ignore
                for r in visit(element.right):  # type: ignore
                    fn(stack[-1], l, r)
            stack.pop()
            # a comparison yields nothing to its enclosing expression.
            # the previous version followed this with a loop that called
            # visit() on each child without iterating the generator it
            # returned; that loop never executed any traversal and has
            # been removed.
        else:
            if isinstance(element, ColumnClause):
                yield element
            for elem in element.get_children():
                yield from visit(elem)

    # the generator has to be exhausted for the traversal to take place
    list(visit(expr))
    visit = None  # type: ignore  # remove gc cycles
324
325
def find_tables(
    clause: ClauseElement,
    *,
    check_columns: bool = False,
    include_aliases: bool = False,
    include_joins: bool = False,
    include_selects: bool = False,
    include_crud: bool = False,
) -> List[TableClause]:
    """Locate Table objects within the given expression.

    :param clause: the expression to traverse.

    :param check_columns: also collect the ``.table`` of each visited
     "column" element.

    :param include_aliases: also collect "alias", "subquery",
     "tablesample" and "lateral" elements.

    :param include_joins: also collect "join" elements.

    :param include_selects: also collect "select" and "compound_select"
     elements.

    :param include_crud: also collect the ``.table`` of "insert",
     "update" and "delete" elements.

    :return: list of collected objects, in the order they were visited.

    """

    found: List[TableClause] = []
    collect = found.append
    dispatch: Dict[str, _TraverseCallableType[Any]] = {}

    if include_selects:
        for visit_name in ("select", "compound_select"):
            dispatch[visit_name] = collect

    if include_joins:
        dispatch["join"] = collect

    if include_aliases:
        for visit_name in ("alias", "subquery", "tablesample", "lateral"):
            dispatch[visit_name] = collect

    if include_crud:

        def collect_statement_table(stmt):
            collect(stmt.table)

        for visit_name in ("insert", "update", "delete"):
            dispatch[visit_name] = collect_statement_table

    if check_columns:

        def collect_column_table(column):
            collect(column.table)

        dispatch["column"] = collect_column_table

    # plain tables are always collected
    dispatch["table"] = collect

    visitors.traverse(clause, {}, dispatch)
    return found
367
368
def unwrap_order_by(clause: Any) -> Any:
    """Break up an 'order by' expression into individual column-expressions,
    without DESC/ASC/NULLS FIRST/NULLS LAST.

    :param clause: the ORDER BY expression to take apart.

    :return: list of the distinct expressions located, in the order in
     which they were first encountered.

    """

    seen = util.column_set()
    unwrapped = []
    pending = deque([clause])

    # examples
    # column -> ASC/DESC == column
    # column -> ASC/DESC -> label == column
    # column -> label -> ASC/DESC -> label == column
    # scalar_select -> label -> ASC/DESC == scalar_select -> label

    while pending:
        item = pending.popleft()

        is_ordering_wrapper = isinstance(
            item, UnaryExpression
        ) and operators.is_ordering_modifier(
            item.modifier  # type: ignore
        )

        if not isinstance(item, ColumnElement) or is_ordering_wrapper:
            # not a plain column expression (or an ASC/DESC style
            # wrapper); look at what it contains instead
            pending.extend(item.get_children())
            continue

        if isinstance(item, Label) and not isinstance(
            item.element, ScalarSelect
        ):
            # strip the label (and a grouping directly beneath it), then
            # examine the inner element on a later iteration
            inner = item.element
            if isinstance(inner, Grouping):
                inner = inner.element
            pending.append(inner)
        elif isinstance(item, _label_reference):
            pending.append(item.element)
        elif isinstance(item, _textual_label_reference):
            # textual references are dropped from the result
            pass
        elif item not in seen:
            seen.add(item)
            unwrapped.append(item)

    return unwrapped
414
415
def unwrap_label_reference(element):
    """Return ``element`` with each ``_label_reference`` inside it replaced
    by the element that the reference wraps.

    A ``_textual_label_reference`` cannot be unwrapped; encountering one
    trips an assertion.

    """

    def replace(
        element: ExternallyTraversible, **kw: Any
    ) -> Optional[ExternallyTraversible]:
        if isinstance(element, _label_reference):
            return element.element

        assert not isinstance(
            element, _textual_label_reference
        ), "can't unwrap a textual label reference"

        # None indicates that no replacement is offered for this element
        return None

    return visitors.replacement_traverse(element, {}, replace)
427
428
def expand_column_list_from_order_by(collist, order_by):
    """Given the columns clause and ORDER BY of a selectable,
    return the ORDER BY column expressions which may be added to the
    collist, leaving out those that the collist already contains.

    :param collist: the columns already present in the columns clause.

    :param order_by: iterable of ORDER BY expressions; each one is broken
     up using :func:`.unwrap_order_by`.

    :return: list of column expressions.

    """
    already_present = set()
    for col in collist:
        # for a column that has an "order by label element", compare
        # using the element it wraps
        if col._order_by_label_element is not None:
            already_present.add(col.element)
        else:
            already_present.add(col)

    candidates = []
    for expression in order_by:
        candidates.extend(unwrap_order_by(expression))

    return [col for col in candidates if col not in already_present]
444
445
def clause_is_present(clause, search):
    """Given a target clause and a second to search within, return True
    if the target is plainly present in the search without any
    subqueries or aliases involved.

    The search considers the elements produced by
    :func:`.surface_selectables`; basically it descends through Joins.

    """

    # use == here so that Annotated's compare
    return any(clause == elem for elem in surface_selectables(search))
460
461
def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]:
    """Iterate the non-Join, non-FromGrouping elements within ``clause``,
    leftmost first.

    Joins are expanded into their left side followed by their right side,
    and a FromGrouping is replaced by the element it wraps.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        if isinstance(current, Join):
            # the right side goes on the stack first so that the left
            # side is the next one taken off
            pending.append(current.right)
            pending.append(current.left)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
        else:
            yield current
470
471
def surface_selectables(clause):
    """Iterate ``clause`` along with the elements reachable from it by
    descending into the two sides of each Join and the inner element of
    each FromGrouping.

    Every element taken off the stack is yielded, including the Joins
    and FromGroupings themselves.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        yield current
        if isinstance(current, Join):
            # pushed as left, right; the right side is therefore the
            # next element to be taken off
            pending += (current.left, current.right)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
481
482
def surface_selectables_only(clause: ClauseElement) -> Iterator[ClauseElement]:
    """Iterate the selectables found at the surface of ``clause``.

    Joins and FromGroupings are descended into and not yielded
    themselves; a ColumnClause is replaced by its ``.table`` when it has
    one and is otherwise yielded as is.

    """
    pending = [clause]
    while pending:
        node = pending.pop()

        if isinstance(node, (TableClause, Alias)):
            yield node

        # NOTE(review): the chain below is evaluated independently of the
        # check above, so an element that matched there and is not also a
        # Join, FromGrouping or ColumnClause reaches the last branch and
        # is yielded a second time - confirm that this is intended.
        if isinstance(node, Join):
            pending.extend((node.left, node.right))
        elif isinstance(node, FromGrouping):
            pending.append(node.element)
        elif isinstance(node, ColumnClause):
            owner = node.table
            if owner is None:
                yield node
            else:
                pending.append(owner)
        elif node is not None:
            yield node
500
501
def extract_first_column_annotation(column, annotation_name):
    """Return the value of the first ``annotation_name`` annotation found
    in ``column`` or the elements beneath it.

    The search is breadth first.  Children that are FromGrouping or
    SelectBase objects are not descended into.  Returns None when the
    annotation is not found.

    """
    not_traversed = (FromGrouping, SelectBase)

    queue = deque([column])
    while queue:
        current = queue.popleft()
        annotations = current._annotations
        if annotation_name in annotations:
            return annotations[annotation_name]
        queue.extend(
            child
            for child in current.get_children()
            if not isinstance(child, not_traversed)
        )
    return None
515
516
def selectables_overlap(left: FromClause, right: FromClause) -> bool:
    """Return True if left/right have some overlapping selectable.

    The comparison is made between the elements that
    :func:`.surface_selectables` produces for each side.

    """

    left_surface = set(surface_selectables(left))
    return not left_surface.isdisjoint(surface_selectables(right))
523
524
def bind_values(clause):
    """Return an ordered list of "bound" values in the given clause.

    The ``effective_value`` of each "bindparam" element visited while
    traversing the clause is collected, in visiting order.

    E.g.::

        >>> expr = and_(table.c.foo == 5, table.c.foo == 7)
        >>> bind_values(expr)
        [5, 7]
    """

    values = []

    visitors.traverse(
        clause,
        {},
        {"bindparam": lambda bind: values.append(bind.effective_value)},
    )
    return values
542
543
544def _quote_ddl_expr(element):
545 if isinstance(element, str):
546 element = element.replace("'", "''")
547 return "'%s'" % element
548 else:
549 return repr(element)
550
551
552class _repr_base:
553 _LIST: int = 0
554 _TUPLE: int = 1
555 _DICT: int = 2
556
557 __slots__ = ("max_chars",)
558
559 max_chars: int
560
561 def trunc(self, value: Any) -> str:
562 rep = repr(value)
563 lenrep = len(rep)
564 if lenrep > self.max_chars:
565 segment_length = self.max_chars // 2
566 rep = (
567 rep[0:segment_length]
568 + (
569 " ... (%d characters truncated) ... "
570 % (lenrep - self.max_chars)
571 )
572 + rep[-segment_length:]
573 )
574 return rep
575
576
def _repr_single_value(value):
    """Return the ``repr()`` of ``value``, truncated by
    ``_repr_base.trunc()`` using a limit of 300 characters.

    """
    formatter = _repr_base()
    formatter.max_chars = 300
    return formatter.trunc(value)
581
582
class _repr_row(_repr_base):
    """Provide a string view of a row.

    The view is rendered in the style of a tuple, with each value
    truncated to ``max_chars`` characters.

    """

    __slots__ = ("row",)

    def __init__(self, row: Row[Any], max_chars: int = 300):
        """Store the row to display and the per-value character limit."""
        self.max_chars = max_chars
        self.row = row

    def __repr__(self) -> str:
        rendered = ", ".join(map(self.trunc, self.row))
        # a single element row gets a trailing comma, like a one-tuple
        trailer = "," if len(self.row) == 1 else ""
        return f"({rendered}{trailer})"
598
599
600class _long_statement(str):
601 def __str__(self) -> str:
602 lself = len(self)
603 if lself > 500:
604 lleft = 250
605 lright = 100
606 trunc = lself - lleft - lright
607 return (
608 f"{self[0:lleft]} ... {trunc} "
609 f"characters truncated ... {self[-lright:]}"
610 )
611 else:
612 return str.__str__(self)
613
614
class _repr_params(_repr_base):
    """Provide a string view of bound parameters.

    Truncates display to a given number of 'multi' parameter sets,
    as well as long values to a given number of characters.

    """

    __slots__ = "params", "batches", "ismulti", "max_params"

    def __init__(
        self,
        params: Optional[_AnyExecuteParams],
        batches: int,
        max_params: int = 100,
        max_chars: int = 300,
        ismulti: Optional[bool] = None,
    ):
        """Construct a new parameter view.

        :param params: the parameters to display.

        :param batches: number of parameter sets to display when
         ``params`` holds multiple sets.

        :param max_params: number of individual parameters above which a
         single parameter set is abbreviated.

        :param max_chars: number of characters above which the ``repr()``
         of an individual value is abbreviated.

        :param ismulti: True if ``params`` is a collection of parameter
         sets, False if it is a single set.  When None, the parameters
         are rendered using a truncated ``repr()`` only.

        """
        self.params = params
        self.ismulti = ismulti
        self.batches = batches
        self.max_chars = max_chars
        self.max_params = max_params

    def __repr__(self) -> str:
        if self.ismulti is None:
            return self.trunc(self.params)

        if isinstance(self.params, list):
            typ = self._LIST
        elif isinstance(self.params, tuple):
            typ = self._TUPLE
        elif isinstance(self.params, dict):
            typ = self._DICT
        else:
            # not a container we know how to break up
            return self.trunc(self.params)

        if not self.ismulti:
            return self._repr_params(
                cast(
                    "_AnySingleExecuteParams",
                    self.params,
                ),
                typ,
            )

        multi_params = cast(
            "_AnyMultiExecuteParams",
            self.params,
        )

        if len(self.params) <= self.batches:
            return self._repr_multi(multi_params, typ)

        msg = " ... displaying %i of %i total bound parameter sets ... "

        # the last two sets are always displayed, the remainder of the
        # allotment comes from the front.  the count is clamped at zero
        # so that a "batches" value below two can't turn into a negative
        # slice index, which would select nearly every set.
        leading = max(self.batches - 2, 0)
        return " ".join(
            (
                # drop the closing bracket of the leading part...
                self._repr_multi(multi_params[:leading], typ)[0:-1],
                msg % (self.batches, len(self.params)),
                # ...and the opening bracket of the trailing part
                self._repr_multi(multi_params[-2:], typ)[1:],
            )
        )

    def _repr_multi(
        self,
        multi_params: _AnyMultiExecuteParams,
        typ: int,
    ) -> str:
        """Render a collection of parameter sets.

        :param multi_params: the parameter sets; the type of the first
         set determines how all of them are rendered.

        :param typ: container tag for the outer collection; ``_LIST``
         renders square brackets, anything else renders parenthesis.

        """
        if multi_params:
            if isinstance(multi_params[0], list):
                elem_type = self._LIST
            elif isinstance(multi_params[0], tuple):
                elem_type = self._TUPLE
            elif isinstance(multi_params[0], dict):
                elem_type = self._DICT
            else:
                assert False, "Unknown parameter type %s" % (
                    type(multi_params[0])
                )

            elements = ", ".join(
                self._repr_params(params, elem_type) for params in multi_params
            )
        else:
            elements = ""

        if typ == self._LIST:
            return "[%s]" % elements
        else:
            return "(%s)" % elements

    def _get_batches(self, params: Iterable[Any]) -> Any:
        """Split ``params`` into a leading and a trailing batch if there
        are more than ``max_params`` of them.

        :return: a three-tuple ``(first_batch, second_batch,
         num_truncated)``.  When nothing is truncated, ``first_batch``
         holds every parameter and the other two members are None.

        """
        lparams = list(params)
        lenparams = len(lparams)

        # keep at least one parameter on each side: with a zero length,
        # lparams[-0:] would be the complete list and not an empty one
        lleft = max(self.max_params // 2, 1)

        if lenparams > self.max_params and lenparams > 2 * lleft:
            return (
                lparams[0:lleft],
                lparams[-lleft:],
                # number actually left out; 2 * lleft is one less than
                # max_params when max_params is odd
                lenparams - 2 * lleft,
            )
        else:
            return lparams, None, None

    def _repr_params(
        self,
        params: _AnySingleExecuteParams,
        typ: int,
    ) -> str:
        """Render a single parameter set according to its container tag."""
        # the tags are plain integers; compare by value as _repr_multi()
        # does, not by identity
        if typ == self._DICT:
            return self._repr_param_dict(
                cast("_CoreSingleExecuteParams", params)
            )
        elif typ == self._TUPLE:
            return self._repr_param_tuple(cast("Sequence[Any]", params))
        else:
            return self._repr_param_list(params)

    def _repr_param_dict(self, params: _CoreSingleExecuteParams) -> str:
        """Render a dictionary of parameters as ``{key: value, ...}``."""
        trunc = self.trunc
        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params.items())

        if items_second_batch:
            text = "{%s" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_first_batch
                )
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s}" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_second_batch
                )
            )
        else:
            text = "{%s}" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_first_batch
                )
            )
        return text

    def _repr_param_tuple(self, params: Sequence[Any]) -> str:
        """Render a tuple of parameters as ``(value, ...)``."""
        trunc = self.trunc

        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params)

        if items_second_batch:
            text = "(%s" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s)" % (
                ", ".join(trunc(value) for value in items_second_batch),
            )
        else:
            text = "(%s%s)" % (
                ", ".join(trunc(value) for value in items_first_batch),
                # trailing comma for a one-element tuple
                "," if len(items_first_batch) == 1 else "",
            )
        return text

    def _repr_param_list(self, params: _AnySingleExecuteParams) -> str:
        """Render a list of parameters as ``[value, ...]``."""
        trunc = self.trunc
        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params)

        if items_second_batch:
            text = "[%s" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s]" % (
                ", ".join(trunc(value) for value in items_second_batch)
            )
        else:
            text = "[%s]" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
        return text
815
816
def adapt_criterion_to_null(crit: _CE, nulls: Collection[Any]) -> _CE:
    """Given criterion containing bind params, convert selected elements
    to IS NULL.

    :param crit: the criterion; it is traversed using
     ``visitors.cloned_traverse()`` and the result of that traversal is
     returned.

    :param nulls: collection of bind parameter identifying keys; a binary
     expression comparing against one of these parameters becomes an
     ``IS NULL`` comparison.

    """

    def _targets_null(operand):
        return (
            isinstance(operand, BindParameter)
            and operand._identifying_key in nulls
        )

    def visit_binary(binary):
        if _targets_null(binary.left):
            # reverse order if the NULL is on the left side
            binary.left = binary.right
        elif not _targets_null(binary.right):
            # neither side is one of the selected parameters
            return

        binary.right = Null()
        binary.operator = operators.is_
        binary.negate = operators.is_not

    return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
842
843
def splice_joins(
    left: Optional[FromClause],
    right: Optional[FromClause],
    stop_on: Optional[FromClause] = None,
) -> Optional[FromClause]:
    """Adapt ``right`` in terms of ``left`` by walking down the chain of
    left sides of the Joins in ``right``.

    Each Join along the chain, up to but not including ``stop_on``, is
    cloned and its ON clause is run through a :class:`.ClauseAdapter`
    built from ``left``; the element at the end of the chain is run
    through the same adapter and set as the ``.left`` of the cloned Join
    above it.

    :param left: selectable to adapt against.  When None, ``right`` is
     returned unchanged.

    :param right: the selectable or Join to be adapted.

    :param stop_on: optional Join at which descent stops.

    :return: the adapted version of ``right``.

    """
    if left is None:
        return right

    adapter = ClauseAdapter(left)
    result = None

    # pairs of (element to process, cloned Join it is the left side of)
    pending: List[Tuple[Optional[FromClause], Optional[Join]]] = [
        (right, None)
    ]

    while pending:
        current, parent = pending.pop()

        if isinstance(current, Join) and current is not stop_on:
            current = current._clone()
            current.onclause = adapter.traverse(current.onclause)
            # continue with the left side of the clone
            pending.append((current.left, current))
        else:
            current = adapter.traverse(current)

        if parent is not None:
            assert current is not None
            parent.left = current

        # the first element processed is the outermost one
        if result is None:
            result = current

    return result
871
872
@overload
def reduce_columns(
    columns: Iterable[ColumnElement[Any]],
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Sequence[ColumnElement[Any]]: ...


@overload
def reduce_columns(
    columns: _SelectIterable,
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Sequence[Union[ColumnElement[Any], TextClause]]: ...


def reduce_columns(
    columns: _SelectIterable,
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Collection[Union[ColumnElement[Any], TextClause]]:
    r"""Given a list of columns, return a 'reduced' set based on natural
    equivalents.

    The set is reduced to the smallest list of columns which have no natural
    equivalent present in the list.  A "natural equivalent" means that two
    columns will ultimately represent the same value because they are related
    by a foreign key.

    \*clauses is an optional list of join clauses which will be traversed
    to further identify columns that are "equivalent".

    \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
    whose tables are not yet configured, or columns that aren't yet present.
    It may also specify 'only_synonyms', in which case a column is only
    omitted when it additionally has the same name as the column it is
    equivalent to.

    This function is primarily used to determine the most minimal "primary
    key" from a selectable, by reducing the set of primary key columns present
    in the selectable to just those that are not repeated.

    """
    ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
    only_synonyms = kw.pop("only_synonyms", False)

    column_set = util.OrderedSet(columns)

    # text clauses take no part in the comparisons below
    cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference(
        c for c in column_set if is_text_clause(c)  # type: ignore
    )

    omit = util.column_set()

    for col in cset_no_text:
        foreign_keys = chain.from_iterable(
            [proxied.foreign_keys for proxied in col.proxy_set]
        )
        for fk in foreign_keys:
            for other in cset_no_text:
                if other is col:
                    continue
                try:
                    fk_col = fk.column
                except (
                    exc.NoReferencedColumnError,
                    exc.NoReferencedTableError,
                ):
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    raise

                if fk_col.shares_lineage(other) and (
                    not only_synonyms or other.name == col.name
                ):
                    # "col" refers to a column that's also in the set
                    omit.add(col)
                    break

    if clauses:

        def visit_binary(binary):
            if not (binary.operator == operators.eq):
                return

            # columns still being considered, along with their proxies
            cols = util.column_set(
                chain.from_iterable(
                    [c.proxy_set for c in cset_no_text.difference(omit)]
                )
            )
            if binary.left not in cols or binary.right not in cols:
                return

            for candidate in reversed(cset_no_text):
                if candidate.shares_lineage(binary.right) and (
                    not only_synonyms or candidate.name == binary.left.name
                ):
                    omit.add(candidate)
                    break

        for clause in clauses:
            if clause is not None:
                visitors.traverse(clause, {}, {"binary": visit_binary})

    return column_set.difference(omit)
971
972
def criterion_as_pairs(
    expression,
    consider_as_foreign_keys=None,
    consider_as_referenced_keys=None,
    any_operator=False,
):
    """Traverse an expression and locate binary criterion pairs.

    :param expression: the expression to traverse.

    :param consider_as_foreign_keys: optional collection of columns that
     are to be treated as the foreign key side of a pair.

    :param consider_as_referenced_keys: optional collection of columns
     that are to be treated as the referenced side of a pair.  Can't be
     combined with ``consider_as_foreign_keys``.

    :param any_operator: when False, only binary expressions that use
     the ``eq`` operator are considered.

    :return: list of two-tuples of columns; each tuple is of the form
     ``(referenced column, foreign key column)``.

    :raises ArgumentError: if both ``consider_as_foreign_keys`` and
     ``consider_as_referenced_keys`` are given.

    """

    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError(
            "Can only specify one of "
            "'consider_as_foreign_keys' or "
            "'consider_as_referenced_keys'"
        )

    pairs: List[Tuple[ColumnElement[Any], ColumnElement[Any]]] = []

    def col_is(a, b):
        # return a is b
        return a.compare(b)

    def _is_keyed_side(keys, candidate, opposite):
        # "candidate" is one of the given keys, and "opposite" either
        # compares as the same column or is not one of the keys itself
        return candidate in keys and (
            col_is(opposite, candidate) or opposite not in keys
        )

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return

        left, right = binary.left, binary.right
        if not (
            isinstance(left, ColumnElement)
            and isinstance(right, ColumnElement)
        ):
            return

        if consider_as_foreign_keys:
            if _is_keyed_side(consider_as_foreign_keys, left, right):
                pairs.append((right, left))
            elif _is_keyed_side(consider_as_foreign_keys, right, left):
                pairs.append((left, right))
        elif consider_as_referenced_keys:
            if _is_keyed_side(consider_as_referenced_keys, left, right):
                pairs.append((left, right))
            elif _is_keyed_side(consider_as_referenced_keys, right, left):
                pairs.append((right, left))
        elif isinstance(left, Column) and isinstance(right, Column):
            # no explicit keys given; go by what the columns reference
            if left.references(right):
                pairs.append((right, left))
            elif right.references(left):
                pairs.append((left, right))

    visitors.traverse(expression, {}, {"binary": visit_binary})
    return pairs
1034
1035
1036class ClauseAdapter(visitors.ReplacingExternalTraversal):
1037 """Clones and modifies clauses based on column correspondence.
1038
1039 E.g.::
1040
1041 table1 = Table(
1042 "sometable",
1043 metadata,
1044 Column("col1", Integer),
1045 Column("col2", Integer),
1046 )
1047 table2 = Table(
1048 "someothertable",
1049 metadata,
1050 Column("col1", Integer),
1051 Column("col2", Integer),
1052 )
1053
1054 condition = table1.c.col1 == table2.c.col1
1055
1056 make an alias of table1::
1057
1058 s = table1.alias("foo")
1059
1060 calling ``ClauseAdapter(s).traverse(condition)`` converts
1061 condition to read::
1062
1063 s.c.col1 == table2.c.col1
1064
1065 """
1066
1067 __slots__ = (
1068 "__traverse_options__",
1069 "selectable",
1070 "include_fn",
1071 "exclude_fn",
1072 "equivalents",
1073 "adapt_on_names",
1074 "adapt_from_selectables",
1075 )
1076
1077 def __init__(
1078 self,
1079 selectable: Selectable,
1080 equivalents: Optional[_EquivalentColumnMap] = None,
1081 include_fn: Optional[Callable[[ClauseElement], bool]] = None,
1082 exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
1083 adapt_on_names: bool = False,
1084 anonymize_labels: bool = False,
1085 adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
1086 ):
1087 self.__traverse_options__ = {
1088 "stop_on": [selectable],
1089 "anonymize_labels": anonymize_labels,
1090 }
1091 self.selectable = selectable
1092 self.include_fn = include_fn
1093 self.exclude_fn = exclude_fn
1094 self.equivalents = util.column_dict(equivalents or {})
1095 self.adapt_on_names = adapt_on_names
1096 self.adapt_from_selectables = adapt_from_selectables
1097
1098 if TYPE_CHECKING:
1099
1100 @overload
1101 def traverse(self, obj: Literal[None]) -> None: ...
1102
1103 # note this specializes the ReplacingExternalTraversal.traverse()
1104 # method to state
1105 # that we will return the same kind of ExternalTraversal object as
1106 # we were given. This is probably not 100% true, such as it's
1107 # possible for us to swap out Alias for Table at the top level.
1108 # Ideally there could be overloads specific to ColumnElement and
1109 # FromClause but Mypy is not accepting those as compatible with
1110 # the base ReplacingExternalTraversal
1111 @overload
1112 def traverse(self, obj: _ET) -> _ET: ...
1113
1114 def traverse(
1115 self, obj: Optional[ExternallyTraversible]
1116 ) -> Optional[ExternallyTraversible]: ...
1117
1118 def _corresponding_column(
1119 self, col, require_embedded, _seen=util.EMPTY_SET
1120 ):
1121 newcol = self.selectable.corresponding_column(
1122 col, require_embedded=require_embedded
1123 )
1124 if newcol is None and col in self.equivalents and col not in _seen:
1125 for equiv in self.equivalents[col]:
1126 newcol = self._corresponding_column(
1127 equiv,
1128 require_embedded=require_embedded,
1129 _seen=_seen.union([col]),
1130 )
1131 if newcol is not None:
1132 return newcol
1133
1134 if (
1135 self.adapt_on_names
1136 and newcol is None
1137 and isinstance(col, NamedColumn)
1138 ):
1139 newcol = self.selectable.exported_columns.get(col.name)
1140 return newcol
1141
1142 @util.preload_module("sqlalchemy.sql.functions")
1143 def replace(
1144 self, col: _ET, _include_singleton_constants: bool = False
1145 ) -> Optional[_ET]:
1146 functions = util.preloaded.sql_functions
1147
1148 # TODO: cython candidate
1149
1150 if self.include_fn and not self.include_fn(col): # type: ignore
1151 return None
1152 elif self.exclude_fn and self.exclude_fn(col): # type: ignore
1153 return None
1154
1155 if isinstance(col, FromClause) and not isinstance(
1156 col, functions.FunctionElement
1157 ):
1158 if self.selectable.is_derived_from(col):
1159 if self.adapt_from_selectables:
1160 for adp in self.adapt_from_selectables:
1161 if adp.is_derived_from(col):
1162 break
1163 else:
1164 return None
1165 return self.selectable # type: ignore
1166 elif isinstance(col, Alias) and isinstance(
1167 col.element, TableClause
1168 ):
1169 # we are a SELECT statement and not derived from an alias of a
1170 # table (which nonetheless may be a table our SELECT derives
1171 # from), so return the alias to prevent further traversal
1172 # or
1173 # we are an alias of a table and we are not derived from an
1174 # alias of a table (which nonetheless may be the same table
1175 # as ours) so, same thing
1176 return col # type: ignore
1177 else:
1178 # other cases where we are a selectable and the element
1179 # is another join or selectable that contains a table which our
1180 # selectable derives from, that we want to process
1181 return None
1182
1183 elif not isinstance(col, ColumnElement):
1184 return None
1185 elif not _include_singleton_constants and col._is_singleton_constant:
1186 # dont swap out NULL, TRUE, FALSE for a label name
1187 # in a SQL statement that's being rewritten,
1188 # leave them as the constant. This is first noted in #6259,
1189 # however the logic to check this moved here as of #7154 so that
1190 # it is made specific to SQL rewriting and not all column
1191 # correspondence
1192
1193 return None
1194
1195 if "adapt_column" in col._annotations:
1196 col = col._annotations["adapt_column"]
1197
1198 if TYPE_CHECKING:
1199 assert isinstance(col, KeyedColumnElement)
1200
1201 if self.adapt_from_selectables and col not in self.equivalents:
1202 for adp in self.adapt_from_selectables:
1203 if adp.c.corresponding_column(col, False) is not None:
1204 break
1205 else:
1206 return None
1207
1208 if TYPE_CHECKING:
1209 assert isinstance(col, KeyedColumnElement)
1210
1211 return self._corresponding_column( # type: ignore
1212 col, require_embedded=True
1213 )
1214
1215
class _ColumnLookup(Protocol):
    """Typing protocol for an object that is subscripted with an element.

    The overloads declare that a ``None`` key gives ``None`` and that any
    other key gives back a value of the same kind as the key.

    """

    @overload
    def __getitem__(self, key: None) -> None: ...

    @overload
    def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: ...

    @overload
    def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: ...

    @overload
    def __getitem__(self, key: _ET) -> _ET: ...

    def __getitem__(self, key: Any) -> Any: ...
1230
1231
class ColumnAdapter(ClauseAdapter):
    """A ClauseAdapter with memoization, wrapping and strictness options.

    What this subclass adds:

    * Adapted expressions are remembered in the persistent ``.columns``
      collection, so that adapting an expression E a second time hands back
      the very same object E1 that was produced the first time.  This
      matters in particular for Label objects that are anonymized, letting
      the ColumnAdapter present one consistent "adapted" view of things.

    * Items can be kept out of the persistent collection by the
      include/exclude rules; this is decided independently of hash
      identity, because "annotated" items all have the same hash identity
      as their parent.

    * "Wrapping" capability, so that replacement of an expression E can
      proceed through a series of adapters.  Unlike the visitor's
      "chaining" feature, the resulting object is passed through all
      replacing functions unconditionally, rather than stopping at the
      first one that returns non-None.

    * An ``adapt_required`` option, used by eager loading to state that a
      result row column which is not translated is not to be trusted.
      This prevents a column from being interpreted as that of the child
      row in a self-referential scenario, see
      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency

    """

    __slots__ = (
        "columns",
        "adapt_required",
        "allow_label_resolve",
        "_wrap",
        "__weakref__",
    )

    columns: _ColumnLookup

    class _IncludeExcludeMapping:
        """Subscriptable front for ``columns`` that applies the parent
        adapter's include/exclude rules before delegating."""

        def __init__(self, parent, columns):
            self.parent = parent
            self.columns = columns

        def __getitem__(self, key):
            parent = self.parent
            include_fn = parent.include_fn
            exclude_fn = parent.exclude_fn

            filtered_out = (include_fn and not include_fn(key)) or (
                exclude_fn and exclude_fn(key)
            )
            if not filtered_out:
                return self.columns[key]

            # a filtered key is still offered to the wrapped adapter, if
            # there is one; otherwise it comes back unchanged
            if parent._wrap:
                return parent._wrap.columns[key]
            return key

    def __init__(
        self,
        selectable: Selectable,
        equivalents: Optional[_EquivalentColumnMap] = None,
        adapt_required: bool = False,
        include_fn: Optional[Callable[[ClauseElement], bool]] = None,
        exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
        adapt_on_names: bool = False,
        allow_label_resolve: bool = True,
        anonymize_labels: bool = False,
        adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
    ):
        """Build the adapter; arguments other than ``adapt_required`` and
        ``allow_label_resolve`` are forwarded to ClauseAdapter.

        :param adapt_required: when true, looking up a column that was not
         changed by adaptation yields None instead of the column.
        :param allow_label_resolve: value copied onto the
         ``_allow_label_resolve`` attribute of each newly adapted column.

        """
        super().__init__(
            selectable,
            equivalents,
            include_fn=include_fn,
            exclude_fn=exclude_fn,
            adapt_on_names=adapt_on_names,
            anonymize_labels=anonymize_labels,
            adapt_from_selectables=adapt_from_selectables,
        )

        # memoizing lookup, fronted by the include/exclude filter only when
        # such rules are actually present
        lookup: Any = util.WeakPopulateDict(self._locate_col)
        if self.include_fn or self.exclude_fn:
            lookup = self._IncludeExcludeMapping(self, lookup)
        self.columns = lookup

        self.adapt_required = adapt_required
        self.allow_label_resolve = allow_label_resolve
        self._wrap = None

    def wrap(self, adapter):
        """Return a copy of this adapter whose results are additionally
        passed through ``adapter``.

        The copy gets its own fresh ``columns`` collection.

        """
        wrapped = copy.copy(self)
        wrapped._wrap = adapter

        lookup: Any = util.WeakPopulateDict(wrapped._locate_col)
        if wrapped.include_fn or wrapped.exclude_fn:
            lookup = self._IncludeExcludeMapping(wrapped, lookup)
        wrapped.columns = lookup

        return wrapped

    @overload
    def traverse(self, obj: Literal[None]) -> None: ...

    @overload
    def traverse(self, obj: _ET) -> _ET: ...

    def traverse(
        self, obj: Optional[ExternallyTraversible]
    ) -> Optional[ExternallyTraversible]:
        """Return the adapted form of ``obj`` by way of the ``columns``
        collection."""
        return self.columns[obj]

    def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
        """Chain ``visitor``, which must itself be a ColumnAdapter, using
        the superclass implementation."""
        assert isinstance(visitor, ColumnAdapter)

        chained = super().chain(visitor)
        return chained

    if TYPE_CHECKING:

        @property
        def visitor_iterator(self) -> Iterator[ColumnAdapter]: ...

    adapt_clause = traverse
    adapt_list = ClauseAdapter.copy_and_process

    def adapt_check_present(
        self, col: ColumnElement[Any]
    ) -> Optional[ColumnElement[Any]]:
        """Adapt ``col``, returning None if it came back unchanged and has
        no corresponding column either."""
        adapted = self.columns[col]

        if adapted is not col:
            return adapted
        if self._corresponding_column(col, True) is None:
            return None
        return adapted

    def _locate_col(
        self, col: ColumnElement[Any]
    ) -> Optional[ColumnElement[Any]]:
        """Compute the adapted form of ``col``; this is the function that
        populates the ``columns`` collection.

        Returns None when ``adapt_required`` is set and ``col`` came through
        unchanged.

        """
        # both replace and traverse() are overly complicated for what we
        # are doing here, and an inlined version that builds up less
        # overhead would serve better.  the issue is that sometimes the
        # lookup does in fact have to adapt the insides of say a labeled
        # scalar subquery.  However, if the object is an Immutable, i.e.
        # Column objects, the "clone" / "copy internals" part can be skipped
        # since those will be no-ops in any case.  additionally, singleton
        # objects null/true/false are to be caught here and adapted as well.

        if not col._is_immutable:
            c = ClauseAdapter.traverse(self, col)
        else:
            # first visitor that offers a replacement wins; with none, the
            # column stays as it is
            c = col
            for vis in self.visitor_iterator:
                replacement = vis.replace(
                    col, _include_singleton_constants=True
                )
                if replacement is not None:
                    c = replacement
                    break

        if self._wrap:
            rewrapped = self._wrap._locate_col(c)
            if rewrapped is not None:
                c = rewrapped

        if c is col:
            # nothing was adapted
            return None if self.adapt_required else c

        # allow_label_resolve is consumed by one case for joined eager
        # loading, as part of its logic to prevent its own columns from
        # being affected by .order_by().  Before full typing was applied to
        # the ORM, this logic would set the attribute on the incoming
        # object (typically a column, though there is a test for it being
        # a non-column object) if no column were found.  While that seemed
        # to have no negative effects, the adjustment should only occur on
        # the new column, which is assumed to be local to an adapted
        # selectable.
        c._allow_label_resolve = self.allow_label_resolve

        return c
1403
1404
def _offset_or_limit_clause(
    element: _LimitOffsetType,
    name: Optional[str] = None,
    type_: Optional[_TypeEngineArgument[int]] = None,
) -> ColumnElement[int]:
    """Coerce ``element`` into an "offset or limit" clause.

    Incoming integers are converted to an expression; an element that is
    already an expression is passed through.  The work is done by
    ``coercions.expect()`` using ``roles.LimitOffsetRole``, with ``name``
    and ``type_`` forwarded as keyword arguments.

    """
    coerced = coercions.expect(
        roles.LimitOffsetRole,
        element,
        name=name,
        type_=type_,
    )
    return coerced
1419
1420
def _offset_or_limit_clause_asint_if_possible(
    clause: _LimitOffsetType,
) -> _LimitOffsetType:
    """Reduce an offset or limit clause to a simple integer where possible.

    ``None`` stays ``None``.  A clause that has a ``_limit_offset_value``
    attribute is replaced by ``util.asint()`` of that value.  Any other
    clause is returned as given.

    """
    if clause is None:
        return None

    if not hasattr(clause, "_limit_offset_value"):
        # no plain value available; hand the clause back untouched
        return clause

    return util.asint(clause._limit_offset_value)
1435
1436
1437def _make_slice(
1438 limit_clause: _LimitOffsetType,
1439 offset_clause: _LimitOffsetType,
1440 start: int,
1441 stop: int,
1442) -> Tuple[Optional[ColumnElement[int]], Optional[ColumnElement[int]]]:
1443 """Compute LIMIT/OFFSET in terms of slice start/end"""
1444
1445 # for calculated limit/offset, try to do the addition of
1446 # values to offset in Python, however if a SQL clause is present
1447 # then the addition has to be on the SQL side.
1448
1449 # TODO: typing is finding a few gaps in here, see if they can be
1450 # closed up
1451
1452 if start is not None and stop is not None:
1453 offset_clause = _offset_or_limit_clause_asint_if_possible(
1454 offset_clause
1455 )
1456 if offset_clause is None:
1457 offset_clause = 0
1458
1459 if start != 0:
1460 offset_clause = offset_clause + start # type: ignore
1461
1462 if offset_clause == 0:
1463 offset_clause = None
1464 else:
1465 assert offset_clause is not None
1466 offset_clause = _offset_or_limit_clause(offset_clause)
1467
1468 limit_clause = _offset_or_limit_clause(stop - start)
1469
1470 elif start is None and stop is not None:
1471 limit_clause = _offset_or_limit_clause(stop)
1472 elif start is not None and stop is None:
1473 offset_clause = _offset_or_limit_clause_asint_if_possible(
1474 offset_clause
1475 )
1476 if offset_clause is None:
1477 offset_clause = 0
1478
1479 if start != 0:
1480 offset_clause = offset_clause + start
1481
1482 if offset_clause == 0:
1483 offset_clause = None
1484 else:
1485 offset_clause = _offset_or_limit_clause(offset_clause)
1486
1487 return limit_clause, offset_clause