1# sql/util.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""High level utilities which build upon other modules here.
10
11"""
12from __future__ import annotations
13
14from collections import deque
15import copy
16from itertools import chain
17import typing
18from typing import AbstractSet
19from typing import Any
20from typing import Callable
21from typing import cast
22from typing import Collection
23from typing import Dict
24from typing import Iterable
25from typing import Iterator
26from typing import List
27from typing import Optional
28from typing import overload
29from typing import Protocol
30from typing import Sequence
31from typing import Tuple
32from typing import TYPE_CHECKING
33from typing import TypeVar
34from typing import Union
35
36from . import coercions
37from . import operators
38from . import roles
39from . import visitors
40from ._typing import is_text_clause
41from .annotation import _deep_annotate as _deep_annotate # noqa: F401
42from .annotation import _deep_deannotate as _deep_deannotate # noqa: F401
43from .annotation import _shallow_annotate as _shallow_annotate # noqa: F401
44from .base import _expand_cloned
45from .base import _from_objects
46from .cache_key import HasCacheKey as HasCacheKey # noqa: F401
47from .ddl import sort_tables as sort_tables # noqa: F401
48from .elements import _find_columns as _find_columns
49from .elements import _label_reference
50from .elements import _textual_label_reference
51from .elements import BindParameter
52from .elements import ClauseElement
53from .elements import ColumnClause
54from .elements import ColumnElement
55from .elements import Grouping
56from .elements import KeyedColumnElement
57from .elements import Label
58from .elements import NamedColumn
59from .elements import Null
60from .elements import UnaryExpression
61from .schema import Column
62from .selectable import Alias
63from .selectable import FromClause
64from .selectable import FromGrouping
65from .selectable import Join
66from .selectable import ScalarSelect
67from .selectable import SelectBase
68from .selectable import TableClause
69from .visitors import _ET
70from .. import exc
71from .. import util
72from ..util.typing import Literal
73from ..util.typing import Unpack
74
75if typing.TYPE_CHECKING:
76 from ._typing import _EquivalentColumnMap
77 from ._typing import _LimitOffsetType
78 from ._typing import _TypeEngineArgument
79 from .elements import BinaryExpression
80 from .elements import TextClause
81 from .selectable import _JoinTargetElement
82 from .selectable import _SelectIterable
83 from .selectable import Selectable
84 from .visitors import _TraverseCallableType
85 from .visitors import ExternallyTraversible
86 from .visitors import ExternalTraversal
87 from ..engine.interfaces import _AnyExecuteParams
88 from ..engine.interfaces import _AnyMultiExecuteParams
89 from ..engine.interfaces import _AnySingleExecuteParams
90 from ..engine.interfaces import _CoreSingleExecuteParams
91 from ..engine.row import Row
92
# Type variable bound to ColumnElement.  It lets a function declare that it
# returns the same kind of column element it was given, as
# adapt_criterion_to_null() in this module does.
_CE = TypeVar("_CE", bound="ColumnElement[Any]")
94
95
def join_condition(
    a: FromClause,
    b: FromClause,
    a_subset: Optional[FromClause] = None,
    consider_as_foreign_keys: Optional[AbstractSet[ColumnClause[Any]]] = None,
) -> ColumnElement[bool]:
    """Build the join criteria between two tables or selectables.

    For example::

        join_condition(tablea, tableb)

    yields an expression along the lines of::

        tablea.c.id==tableb.c.tablea_id

    The criteria are derived from the foreign key relationships between
    the two selectables.  An error is raised when there is no way to
    join them, or when more than one way exists.

    This function is a thin public wrapper; all of the work is delegated
    to ``Join._join_condition()``.

    :param a: the first selectable.

    :param b: the second selectable.

    :param a_subset: An optional expression that is a sub-component
     of ``a``.  A join to just this sub-component is attempted first,
     before the full ``a`` construct is considered; if that succeeds it
     wins even when other ways to join to ``a`` exist.  This allows the
     "right side" of a join to be passed, thereby providing a
     "natural join".

    :param consider_as_foreign_keys: optional set of columns, handed to
     ``Join._join_condition()`` unchanged.

    :return: whatever ``Join._join_condition()`` returns for these
     arguments.

    """
    condition = Join._join_condition(
        a,
        b,
        a_subset=a_subset,
        consider_as_foreign_keys=consider_as_foreign_keys,
    )
    return condition
130
131
def find_join_source(
    clauses: List[FromClause], join_to: FromClause
) -> List[int]:
    """Return the indexes of the FROM clauses that derive from ``join_to``.

    ``join_to`` is first expanded into its FROM objects; every entry of
    ``clauses`` is then tested against each of those objects using
    ``is_derived_from()``.

    e.g.::

        clause1 = table1.join(table2)
        clause2 = table4.join(table5)

        join_to = table2.join(table3)

        find_join_source([clause1, clause2], join_to)

    :param clauses: list of FROM clauses to search.

    :param join_to: the selectable being joined to.

    :return: a list of integer positions within ``clauses``, in
     ascending order.  A position is added once for each FROM object of
     ``join_to`` that the clause derives from, so it may appear more
     than once.  The list is empty when nothing matches.

    """
    targets = list(_from_objects(join_to))
    return [
        position
        for position, from_clause in enumerate(clauses)
        for target in targets
        if from_clause.is_derived_from(target)
    ]
158
159
def find_left_clause_that_matches_given(
    clauses: Sequence[FromClause], join_from: FromClause
) -> List[int]:
    """Return the indexes of the FROM clauses derived from ``join_from``.

    Two passes are made.  The first, "liberal" pass accepts any clause
    for which ``is_derived_from()`` is true against one of the FROM
    objects of ``join_from``.  Only if that pass finds more than one
    clause, a second, "conservative" pass keeps just the clauses whose
    surface selectables directly overlap those of ``join_from``; when
    the conservative pass finds nothing, the liberal result is returned.

    :param clauses: sequence of FROM clauses to search.

    :param join_from: the selectable to be located among ``clauses``.

    :return: list of integer positions within ``clauses``.

    """
    targets = list(_from_objects(join_from))

    # basic check, whether the clause is derived from a target.  this can
    # be joins containing a table, or an aliased table or select
    # statement matching to a table.  This check will match a table to a
    # selectable that is adapted from that table.  With Query, this
    # suits the case where a join is being made to an adapted entity
    derived = [
        position
        for position, from_clause in enumerate(clauses)
        if any(from_clause.is_derived_from(target) for target in targets)
    ]

    if len(derived) <= 1:
        return derived

    # in an extremely small set of use cases, a join is being made where
    # there are multiple FROM clauses where our target table is
    # represented in more than one, such as embedded or similar.  in
    # this case, do another pass where we try to get a more exact match
    # where we aren't looking at adaption relationships.
    exact = [
        position
        for position in derived
        if any(
            set(surface_selectables(clauses[position])).intersection(
                surface_selectables(target)
            )
            for target in targets
        )
    ]
    return exact if exact else derived
202
203
def find_left_clause_to_join_from(
    clauses: Sequence[FromClause],
    join_to: _JoinTargetElement,
    onclause: Optional[ColumnElement[Any]],
) -> List[int]:
    """Return the indexes of the FROM clauses that can be joined from.

    Given a list of FROM clauses, a join target and an optional ON
    clause, each clause is compared against every FROM object of
    ``join_to`` other than the clause itself:

    * when more than one clause is given together with an ``onclause``,
      a clause matches if its columns combined with the target's columns
      cover every column found in the ``onclause``;
    * otherwise a clause matches if an ``onclause`` was given at all, or
      if ``Join._can_join()`` reports the two as joinable.

    If more than one clause matched, clauses that appear in another
    clause's ``_hide_froms`` are dropped.  If nothing matched but an
    ``onclause`` was given, every index is returned, since the presence
    of an ON clause means at least one clause can definitely be joined
    from.

    :param clauses: sequence of candidate FROM clauses.

    :param join_to: the target of the join.

    :param onclause: optional ON clause expression.

    :return: list of integer positions within ``clauses``.

    """
    targets = set(_from_objects(join_to))

    # if we are given more than one target clause to join
    # from, use the onclause to provide a more specific answer.
    # otherwise, don't try to limit, after all, "ON TRUE" is a valid
    # on clause
    use_onclause_columns = len(clauses) > 1 and onclause is not None
    onclause_columns = (
        _find_columns(onclause) if use_onclause_columns else None
    )

    def can_join_from(from_clause, target):
        if use_onclause_columns:
            assert onclause_columns is not None
            return (
                set(from_clause.c)
                .union(target.c)
                .issuperset(onclause_columns)
            )
        return onclause is not None or Join._can_join(from_clause, target)

    matched = [
        position
        for position, from_clause in enumerate(clauses)
        if any(
            can_join_from(from_clause, target)
            for target in targets.difference([from_clause])
        )
    ]

    if len(matched) > 1:
        # this is the same "hide froms" logic from
        # Selectable._get_display_froms
        hidden = set(
            chain.from_iterable(
                _expand_cloned(from_clause._hide_froms)
                for from_clause in clauses
            )
        )
        matched = [
            position for position in matched if clauses[position] not in hidden
        ]

    if matched or onclause is None:
        return matched

    # onclause was given and none of them resolved, so assume
    # all indexes can match
    return list(range(len(clauses)))
259
260
def visit_binary_product(
    fn: Callable[
        [BinaryExpression[Any], ColumnElement[Any], ColumnElement[Any]], None
    ],
    expr: ColumnElement[Any],
) -> None:
    """Produce a traversal of the given expression, delivering
    column comparisons to the given function.

    The function is of the form::

        def my_fn(binary, left, right)

    For each binary expression located which has a
    comparison operator, the product of "left" and
    "right" will be delivered to that function,
    in terms of that binary.

    Hence an expression like::

        and_(
            (a + b) == q + func.sum(e + f),
            j == r
        )

    would have the traversal::

        a <eq> q
        a <eq> e
        a <eq> f
        b <eq> q
        b <eq> e
        b <eq> f
        j <eq> r

    That is, every combination of "left" and
    "right" that doesn't further contain
    a binary comparison is passed as pairs.

    :param fn: callable invoked as ``fn(binary, left, right)`` for each
     pair; its return value is ignored.

    :param expr: the expression to traverse.

    """
    # comparison expressions currently being expanded, innermost last
    stack: List[BinaryExpression[Any]] = []

    def visit(element: ClauseElement) -> Iterator[ColumnElement[Any]]:
        if isinstance(element, ScalarSelect):
            # we don't want to dig into correlated subqueries,
            # those are just column elements by themselves
            yield element
        elif element.__visit_name__ == "binary" and operators.is_comparison(
            element.operator  # type: ignore
        ):
            # a comparison yields nothing to its caller; instead it
            # reports the product of its two sides to fn().  comparisons
            # nested inside either side are reached through the
            # recursive visit() calls below, so no further pass over the
            # children is needed here.
            stack.append(element)  # type: ignore
            for left in visit(element.left):  # type: ignore
                for right in visit(element.right):  # type: ignore
                    fn(stack[-1], left, right)
            stack.pop()
        else:
            if isinstance(element, ColumnClause):
                yield element
            for elem in element.get_children():
                yield from visit(elem)

    # visit() is a generator function; it has to be exhausted for the
    # traversal (and the calls to fn) to actually take place
    list(visit(expr))
    visit = None  # type: ignore  # remove gc cycles
326
327
def find_tables(
    clause: ClauseElement,
    *,
    check_columns: bool = False,
    include_aliases: bool = False,
    include_joins: bool = False,
    include_selects: bool = False,
    include_crud: bool = False,
) -> List[TableClause]:
    """Locate Table objects within the given expression.

    The expression is traversed with ``visitors.traverse()``; elements
    whose visit name is ``"table"`` are always collected, and the flags
    below add further visit names to collect.

    :param clause: the expression to search.

    :param check_columns: also collect the ``.table`` of each element
     visited as ``"column"``.

    :param include_aliases: also collect elements visited as
     ``"alias"``, ``"subquery"``, ``"tablesample"`` or ``"lateral"``.

    :param include_joins: also collect elements visited as ``"join"``.

    :param include_selects: also collect elements visited as
     ``"select"`` or ``"compound_select"``.

    :param include_crud: also collect the ``.table`` of elements visited
     as ``"insert"``, ``"update"`` or ``"delete"``.

    :return: list of collected objects, in the order visited.

    """
    found: List[TableClause] = []
    collect = found.append
    handlers: Dict[str, _TraverseCallableType[Any]] = {}

    if include_selects:
        for visit_name in ("select", "compound_select"):
            handlers[visit_name] = collect

    if include_joins:
        handlers["join"] = collect

    if include_aliases:
        for visit_name in ("alias", "subquery", "tablesample", "lateral"):
            handlers[visit_name] = collect

    if include_crud:

        def collect_dml_table(ent):
            found.append(ent.table)

        for visit_name in ("insert", "update", "delete"):
            handlers[visit_name] = collect_dml_table

    if check_columns:
        handlers["column"] = lambda column: found.append(column.table)

    handlers["table"] = collect

    visitors.traverse(clause, {}, handlers)
    return found
369
370
def unwrap_order_by(clause: Any) -> Any:
    """Break up an 'order by' expression into individual column-expressions,
    without DESC/ASC/NULLS FIRST/NULLS LAST.

    Elements are processed first-in, first-out from a work queue that
    starts with ``clause``.  Wrappers (ordering modifiers, labels,
    groupings directly under a label, label references) are peeled off
    and what they wrap is re-queued; textual label references are
    dropped.

    :param clause: the ORDER BY expression to unwrap.

    :return: list of the unwrapped column expressions, each appearing
     once, in the order they were first reached.

    """

    # "cols" tracks what has already been emitted so that "result" holds
    # each expression only once while keeping its order
    cols = util.column_set()
    result = []
    stack = deque([clause])

    # examples
    # column -> ASC/DESC == column
    # column -> ASC/DESC -> label == column
    # column -> label -> ASC/DESC -> label == column
    # scalar_select -> label -> ASC/DESC == scalar_select -> label

    while stack:
        t = stack.popleft()
        # a column element that is not an ASC/DESC/NULLS FIRST/NULLS LAST
        # wrapper; such wrappers fall to the "else" below, where their
        # children are queued instead
        if isinstance(t, ColumnElement) and (
            not isinstance(t, UnaryExpression)
            or not operators.is_ordering_modifier(t.modifier)  # type: ignore
        ):
            # a label is replaced by what it labels, except when it
            # labels a scalar select, which is kept together with its
            # label
            if isinstance(t, Label) and not isinstance(
                t.element, ScalarSelect
            ):
                t = t.element

                # a grouping directly beneath the label is removed too
                if isinstance(t, Grouping):
                    t = t.element

                # re-queue; the unwrapped element may itself need
                # unwrapping
                stack.append(t)
                continue
            elif isinstance(t, _label_reference):
                t = t.element

                stack.append(t)
                continue
            # textual label references contribute nothing to the result
            if isinstance(t, (_textual_label_reference)):
                continue
            if t not in cols:
                cols.add(t)
                result.append(t)

        else:
            # ordering modifier or non-column element: descend into it
            for c in t.get_children():
                stack.append(c)
    return result
416
417
def unwrap_label_reference(element):
    """Return ``element`` with each ``_label_reference`` inside it
    replaced by the element that it wraps.

    The replacement is carried out by ``visitors.replacement_traverse()``.
    Encountering a ``_textual_label_reference`` trips an assertion, as
    such a reference cannot be unwrapped.

    """

    def replace(
        element: ExternallyTraversible, **kw: Any
    ) -> Optional[ExternallyTraversible]:
        if isinstance(element, _label_reference):
            return element.element
        if isinstance(element, _textual_label_reference):
            assert False, "can't unwrap a textual label reference"
        # None leaves the element as it is
        return None

    unwrapped = visitors.replacement_traverse(element, {}, replace)
    return unwrapped
429
430
def expand_column_list_from_order_by(collist, order_by):
    """Return the ORDER BY expressions that are missing from ``collist``.

    Given the columns clause and the ORDER BY of a selectable, produce
    the list of column expressions that could be added to the columns
    clause so that it covers the ORDER BY, leaving out those the columns
    clause already has.

    :param collist: the columns clause; an entry whose
     ``_order_by_label_element`` is not None is compared by its
     ``.element``.

    :param order_by: iterable of ORDER BY expressions; each is broken up
     with ``unwrap_order_by()``.

    :return: list of column expressions.

    """
    present = set()
    for col in collist:
        if col._order_by_label_element is not None:
            present.add(col.element)
        else:
            present.add(col)

    candidates = []
    for order_expr in order_by:
        candidates.extend(unwrap_order_by(order_expr))

    return [candidate for candidate in candidates if candidate not in present]
446
447
def clause_is_present(clause, search):
    """Report whether ``clause`` is plainly present within ``search``.

    "Plainly" means without any subqueries or aliases involved: only
    the elements produced by ``surface_selectables(search)`` are
    examined, which basically descends through Joins.

    :return: True if one of those elements compares equal to
     ``clause``, else False.

    """
    # use == here so that Annotated's compare
    return any(clause == elem for elem in surface_selectables(search))
462
463
def tables_from_leftmost(clause: FromClause) -> Iterator[FromClause]:
    """Iterate the elements that make up ``clause``, leftmost first.

    Joins are expanded into their left side followed by their right
    side, and FromGrouping objects into the element they wrap; anything
    else is yielded as is.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        if isinstance(current, Join):
            # push the right side first so the left side is taken next
            pending.append(current.right)
            pending.append(current.left)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
        else:
            yield current
472
473
def surface_selectables(clause):
    """Iterate ``clause`` along with the selectables found by descending
    through Joins and FromGrouping objects.

    Every element reached is yielded, the Join and FromGrouping
    containers included.  Nothing else is descended into.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        yield current
        if isinstance(current, Join):
            pending.append(current.left)
            pending.append(current.right)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
483
484
def surface_selectables_only(clause):
    """Iterate the selectables at the surface of ``clause``.

    Joins and FromGrouping objects are descended into but not yielded
    themselves.  A ColumnClause is replaced by its ``.table`` when it
    has one and is yielded itself otherwise.  Any other element that is
    not None is yielded.

    """
    stack = [clause]
    while stack:
        elem = stack.pop()
        if isinstance(elem, (TableClause, Alias)):
            yield elem
        # NOTE(review): this is a separate "if", not an "elif", so a
        # TableClause or Alias that matches none of the first three
        # branches below reaches the final "elif" and is yielded a
        # second time.  Confirm whether callers rely on this before
        # changing it.
        if isinstance(elem, Join):
            stack.extend((elem.left, elem.right))
        elif isinstance(elem, FromGrouping):
            stack.append(elem.element)
        elif isinstance(elem, ColumnClause):
            if elem.table is not None:
                stack.append(elem.table)
            else:
                yield elem
        elif elem is not None:
            yield elem
502
503
def extract_first_column_annotation(column, annotation_name):
    """Return the first value found for ``annotation_name`` within
    ``column``.

    ``column`` and then its children are searched first-in, first-out,
    looking at each element's ``_annotations``.  Children that are
    FromGrouping or SelectBase objects are not searched.

    :return: the annotation value, or None if no searched element
     carries the annotation.

    """
    skipped_types = (FromGrouping, SelectBase)

    queue = deque([column])
    while queue:
        current = queue.popleft()
        if annotation_name in current._annotations:
            return current._annotations[annotation_name]
        queue.extend(
            child
            for child in current.get_children()
            if not isinstance(child, skipped_types)
        )
    return None
517
518
def selectables_overlap(left: FromClause, right: FromClause) -> bool:
    """Return True if left/right have some overlapping selectable.

    The comparison is made between the elements that
    ``surface_selectables()`` produces for each side.

    """
    left_surface = set(surface_selectables(left))
    return not left_surface.isdisjoint(surface_selectables(right))
525
526
def bind_values(clause):
    """Return an ordered list of "bound" values in the given clause.

    The clause is traversed and the ``effective_value`` of each element
    visited as ``"bindparam"`` is collected, in visit order.

    E.g.::

        >>> expr = and_(
        ...    table.c.foo==5, table.c.foo==7
        ... )
        >>> bind_values(expr)
        [5, 7]
    """
    values = []
    visitors.traverse(
        clause,
        {},
        {"bindparam": lambda bind: values.append(bind.effective_value)},
    )
    return values
546
547
548def _quote_ddl_expr(element):
549 if isinstance(element, str):
550 element = element.replace("'", "''")
551 return "'%s'" % element
552 else:
553 return repr(element)
554
555
556class _repr_base:
557 _LIST: int = 0
558 _TUPLE: int = 1
559 _DICT: int = 2
560
561 __slots__ = ("max_chars",)
562
563 max_chars: int
564
565 def trunc(self, value: Any) -> str:
566 rep = repr(value)
567 lenrep = len(rep)
568 if lenrep > self.max_chars:
569 segment_length = self.max_chars // 2
570 rep = (
571 rep[0:segment_length]
572 + (
573 " ... (%d characters truncated) ... "
574 % (lenrep - self.max_chars)
575 )
576 + rep[-segment_length:]
577 )
578 return rep
579
580
def _repr_single_value(value):
    """Return the ``repr()`` of ``value``, shortened by
    ``_repr_base.trunc()`` using a limit of 300 characters.

    """
    truncator = _repr_base()
    truncator.max_chars = 300
    return truncator.trunc(value)
585
586
class _repr_row(_repr_base):
    """Provide a string view of a row.

    The view looks like a tuple literal in which each value has been
    passed through ``trunc()``.

    """

    __slots__ = ("row",)

    def __init__(
        self, row: Row[Unpack[Tuple[Any, ...]]], max_chars: int = 300
    ):
        """Store the row to display and the per-value character limit."""
        self.row = row
        self.max_chars = max_chars

    def __repr__(self) -> str:
        rendered = ", ".join([self.trunc(value) for value in self.row])
        # a one-element row gets a trailing comma, like a 1-tuple
        trailing = "," if len(self.row) == 1 else ""
        return "(" + rendered + trailing + ")"
604
605
606class _long_statement(str):
607 def __str__(self) -> str:
608 lself = len(self)
609 if lself > 500:
610 lleft = 250
611 lright = 100
612 trunc = lself - lleft - lright
613 return (
614 f"{self[0:lleft]} ... {trunc} "
615 f"characters truncated ... {self[-lright:]}"
616 )
617 else:
618 return str.__str__(self)
619
620
class _repr_params(_repr_base):
    """Provide a string view of bound parameters.

    Truncates display to a given number of 'multi' parameter sets,
    as well as long values to a given number of characters.

    """

    __slots__ = "params", "batches", "ismulti", "max_params"

    def __init__(
        self,
        params: Optional[_AnyExecuteParams],
        batches: int,
        max_params: int = 100,
        max_chars: int = 300,
        ismulti: Optional[bool] = None,
    ):
        """Store the parameters and the display limits.

        :param params: the parameters to display.

        :param batches: number of parameter sets displayed when
         ``params`` holds more sets than this.

        :param max_params: number of individual parameters displayed
         within one parameter set.

        :param max_chars: maximum length of the repr of one value, as
         applied by ``trunc()``.

        :param ismulti: True if ``params`` is a collection of parameter
         sets, False if it is one parameter set; with None, the repr of
         ``params`` as a whole is truncated.

        """
        self.params = params
        self.ismulti = ismulti
        self.batches = batches
        self.max_chars = max_chars
        self.max_params = max_params

    def __repr__(self) -> str:
        if self.ismulti is None:
            return self.trunc(self.params)

        if isinstance(self.params, list):
            typ = self._LIST

        elif isinstance(self.params, tuple):
            typ = self._TUPLE
        elif isinstance(self.params, dict):
            typ = self._DICT
        else:
            return self.trunc(self.params)

        if self.ismulti:
            multi_params = cast(
                "_AnyMultiExecuteParams",
                self.params,
            )

            total = len(self.params)
            if total > self.batches:
                msg = (
                    " ... displaying %i of %i total bound parameter sets ... "
                )
                # show the last two sets and fill the remainder of the
                # budget from the front.  the counts are clamped so that
                # a budget below two cannot produce a negative slice
                # bound, which would select nearly all of the sets.
                tail_count = max(0, min(2, self.batches))
                head_count = max(0, self.batches - tail_count)
                return " ".join(
                    (
                        # drop the closing bracket of the head ...
                        self._repr_multi(
                            multi_params[:head_count],
                            typ,
                        )[0:-1],
                        msg % (self.batches, total),
                        # ... and the opening bracket of the tail
                        self._repr_multi(
                            multi_params[total - tail_count :], typ
                        )[1:],
                    )
                )
            else:
                return self._repr_multi(multi_params, typ)
        else:
            return self._repr_params(
                cast(
                    "_AnySingleExecuteParams",
                    self.params,
                ),
                typ,
            )

    def _repr_multi(
        self,
        multi_params: _AnyMultiExecuteParams,
        typ: int,
    ) -> str:
        """Render a collection of parameter sets.

        The type of the first set decides how every set is rendered;
        ``typ`` decides whether the whole is wrapped in ``[]`` (for
        ``_LIST``) or in ``()``.

        """
        if multi_params:
            if isinstance(multi_params[0], list):
                elem_type = self._LIST
            elif isinstance(multi_params[0], tuple):
                elem_type = self._TUPLE
            elif isinstance(multi_params[0], dict):
                elem_type = self._DICT
            else:
                assert False, "Unknown parameter type %s" % (
                    type(multi_params[0])
                )

            elements = ", ".join(
                self._repr_params(params, elem_type) for params in multi_params
            )
        else:
            elements = ""

        if typ == self._LIST:
            return "[%s]" % elements
        else:
            return "(%s)" % elements

    def _get_batches(self, params: Iterable[Any]) -> Any:
        """Split ``params`` into a leading and a trailing batch.

        :return: a 3-tuple ``(first, second, truncated)``.  When there
         are more than ``max_params`` parameters, ``first`` and
         ``second`` are the leading and trailing ``max_params // 2``
         of them and ``truncated`` is the count in excess of
         ``max_params``.  Otherwise ``first`` holds all parameters and
         the other two members are None.

        """
        lparams = list(params)
        lenparams = len(lparams)
        if lenparams > self.max_params:
            lleft = self.max_params // 2
            return (
                lparams[0:lleft],
                # explicit start index; with lleft == 0,
                # lparams[-lleft:] would be the entire list
                lparams[lenparams - lleft :],
                lenparams - self.max_params,
            )
        else:
            return lparams, None, None

    def _repr_params(
        self,
        params: _AnySingleExecuteParams,
        typ: int,
    ) -> str:
        """Render one parameter set using the renderer for ``typ``."""
        if typ is self._DICT:
            return self._repr_param_dict(
                cast("_CoreSingleExecuteParams", params)
            )
        elif typ is self._TUPLE:
            return self._repr_param_tuple(cast("Sequence[Any]", params))
        else:
            return self._repr_param_list(params)

    def _repr_param_dict(self, params: _CoreSingleExecuteParams) -> str:
        """Render a dictionary parameter set as ``{key: value, ...}``."""
        trunc = self.trunc
        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params.items())

        # trunclen, not the second batch, signals truncation: the second
        # batch is legitimately empty when max_params is below two
        if trunclen is not None:
            text = "{%s" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_first_batch
                )
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s}" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_second_batch
                )
            )
        else:
            text = "{%s}" % (
                ", ".join(
                    f"{key!r}: {trunc(value)}"
                    for key, value in items_first_batch
                )
            )
        return text

    def _repr_param_tuple(self, params: Sequence[Any]) -> str:
        """Render a tuple parameter set as ``(value, ...)``."""
        trunc = self.trunc

        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params)

        # see _repr_param_dict regarding the use of trunclen here
        if trunclen is not None:
            text = "(%s" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s)" % (
                ", ".join(trunc(value) for value in items_second_batch),
            )
        else:
            text = "(%s%s)" % (
                ", ".join(trunc(value) for value in items_first_batch),
                # a one-element tuple gets its trailing comma
                "," if len(items_first_batch) == 1 else "",
            )
        return text

    def _repr_param_list(self, params: _AnySingleExecuteParams) -> str:
        """Render a list parameter set as ``[value, ...]``."""
        trunc = self.trunc
        (
            items_first_batch,
            items_second_batch,
            trunclen,
        ) = self._get_batches(params)

        # see _repr_param_dict regarding the use of trunclen here
        if trunclen is not None:
            text = "[%s" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
            text += f" ... {trunclen} parameters truncated ... "
            text += "%s]" % (
                ", ".join(trunc(value) for value in items_second_batch)
            )
        else:
            text = "[%s]" % (
                ", ".join(trunc(value) for value in items_first_batch)
            )
        return text
821
822
def adapt_criterion_to_null(crit: _CE, nulls: Collection[Any]) -> _CE:
    """Given criterion containing bind params, convert selected elements
    to IS NULL.

    Within the result of ``visitors.cloned_traverse()`` on ``crit``,
    every binary expression that has, on either side, a bound parameter
    whose ``_identifying_key`` is in ``nulls`` is changed to use the
    ``is_`` operator against ``Null()``.

    :param crit: the criterion expression.

    :param nulls: collection of bound parameter identifying keys.

    :return: the result of the cloned traversal.

    """

    def is_null_bind(operand):
        return (
            isinstance(operand, BindParameter)
            and operand._identifying_key in nulls
        )

    def visit_binary(binary):
        if is_null_bind(binary.left):
            # reverse order if the NULL is on the left side
            binary.left = binary.right
        elif not is_null_bind(binary.right):
            return

        binary.right = Null()
        binary.operator = operators.is_
        binary.negate = operators.is_not

    return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
848
849
def splice_joins(
    left: Optional[FromClause],
    right: Optional[FromClause],
    stop_on: Optional[FromClause] = None,
) -> Optional[FromClause]:
    """Adapt ``right`` so that it refers to ``left``, following the
    chain of left sides of any Joins in ``right``.

    Each Join along that chain, other than ``stop_on``, is cloned and
    has its ON clause run through a ``ClauseAdapter`` built on
    ``left``.  The first element that is not such a Join is itself run
    through the adapter and set as the left side of the clone before it.

    :param left: selectable to adapt towards; when None, ``right`` is
     returned untouched.

    :param right: the selectable to adapt.

    :param stop_on: optional Join at which to stop descending; it is
     run through the adapter like a non-Join element.

    :return: the adapted version of the outermost element.

    """
    if left is None:
        return right

    # each entry pairs an element still to be processed with the cloned
    # Join, if any, whose left side it will become
    stack: List[Tuple[Optional[FromClause], Optional[Join]]] = [(right, None)]

    adapter = ClauseAdapter(left)
    ret = None
    while stack:
        (right, prevright) = stack.pop()
        if isinstance(right, Join) and right is not stop_on:
            # work on a clone so the Join passed in is not modified, then
            # continue down its left side
            right = right._clone()
            right.onclause = adapter.traverse(right.onclause)
            stack.append((right.left, right))
        else:
            right = adapter.traverse(right)
        if prevright is not None:
            assert right is not None
            # attach the processed element to the clone above it
            prevright.left = right
        if ret is None:
            # the first element processed is the outermost one; it is
            # what gets returned
            ret = right

    return ret
877
878
# typing-only signature: plain column elements in, column elements out
@overload
def reduce_columns(
    columns: Iterable[ColumnElement[Any]],
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Sequence[ColumnElement[Any]]: ...


# typing-only signature: the input may also contain text clauses
@overload
def reduce_columns(
    columns: _SelectIterable,
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Sequence[Union[ColumnElement[Any], TextClause]]: ...


def reduce_columns(
    columns: _SelectIterable,
    *clauses: Optional[ClauseElement],
    **kw: bool,
) -> Collection[Union[ColumnElement[Any], TextClause]]:
    r"""given a list of columns, return a 'reduced' set based on natural
    equivalents.

    the set is reduced to the smallest list of columns which have no natural
    equivalent present in the list.  A "natural equivalent" means that two
    columns will ultimately represent the same value because they are related
    by a foreign key.

    \*clauses is an optional list of join clauses which will be traversed
    to further identify columns that are "equivalent".

    \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
    whose tables are not yet configured, or columns that aren't yet present.
    It may also specify 'only_synonyms', in which case a column is only
    omitted when its name also matches that of its equivalent.

    This function is primarily used to determine the most minimal "primary
    key" from a selectable, by reducing the set of primary key columns present
    in the selectable to just those that are not repeated.

    Text clauses among the given columns take no part in the reduction
    and remain in the result.

    """
    ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
    only_synonyms = kw.pop("only_synonyms", False)

    column_set = util.OrderedSet(columns)
    # the comparisons below need column attributes, so leave text
    # clauses out of them
    cset_no_text: util.OrderedSet[ColumnElement[Any]] = column_set.difference(
        c for c in column_set if is_text_clause(c)  # type: ignore
    )

    # columns found to be redundant; removed from the result at the end
    omit = util.column_set()
    for col in cset_no_text:
        # consider the foreign keys of every column in col's proxy set
        for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
            for c in cset_no_text:
                if c is col:
                    continue
                try:
                    fk_col = fk.column
                except exc.NoReferencedColumnError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                except exc.NoReferencedTableError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                # col refers, via this foreign key, to a column that is
                # also in the set, so col itself is redundant
                if fk_col.shares_lineage(c) and (
                    not only_synonyms or c.name == col.name
                ):
                    omit.add(col)
                    break

    if clauses:

        def visit_binary(binary):
            # only "left == right" criteria establish equivalence
            if binary.operator == operators.eq:
                # every column, proxies included, not yet omitted
                cols = util.column_set(
                    chain(
                        *[c.proxy_set for c in cset_no_text.difference(omit)]
                    )
                )
                if binary.left in cols and binary.right in cols:
                    # omit the last column in the set that shares
                    # lineage with the right side of the comparison
                    for c in reversed(cset_no_text):
                        if c.shares_lineage(binary.right) and (
                            not only_synonyms or c.name == binary.left.name
                        ):
                            omit.add(c)
                            break

        for clause in clauses:
            if clause is not None:
                visitors.traverse(clause, {}, {"binary": visit_binary})

    return column_set.difference(omit)
977
978
def criterion_as_pairs(
    expression,
    consider_as_foreign_keys=None,
    consider_as_referenced_keys=None,
    any_operator=False,
):
    """Traverse an expression and locate binary criterion pairs.

    Every binary expression inside ``expression`` whose two sides are
    both column elements is examined, and a 2-tuple of its sides is
    collected when the sides are related in one of these ways:

    * ``consider_as_foreign_keys`` is given: one side is in that
      collection and the other side either is not, or compares as the
      same column.  The side found in the collection comes second in
      the pair.
    * ``consider_as_referenced_keys`` is given: as above, but the side
      found in the collection comes first in the pair.
    * neither is given: both sides are ``Column`` objects and one
      ``references()`` the other.  The referenced column comes first in
      the pair.

    :param expression: the expression to traverse.

    :param consider_as_foreign_keys: optional collection of columns.

    :param consider_as_referenced_keys: optional collection of columns.

    :param any_operator: when False, only binary expressions using the
     ``eq`` operator are examined.

    :return: list of 2-tuples of column elements.

    :raises ArgumentError: if ``consider_as_foreign_keys`` and
     ``consider_as_referenced_keys`` are both given.

    """
    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError(
            "Can only specify one of "
            "'consider_as_foreign_keys' or "
            "'consider_as_referenced_keys'"
        )

    pairs: List[Tuple[ColumnElement[Any], ColumnElement[Any]]] = []

    def is_keyed_side(candidate, other, keys):
        # the candidate is one of the given keys, and the other side is
        # either the same column or not one of the keys
        return candidate in keys and (
            other.compare(candidate) or other not in keys
        )

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return

        left, right = binary.left, binary.right
        if not (
            isinstance(left, ColumnElement)
            and isinstance(right, ColumnElement)
        ):
            return

        if consider_as_foreign_keys:
            if is_keyed_side(left, right, consider_as_foreign_keys):
                pairs.append((right, left))
            elif is_keyed_side(right, left, consider_as_foreign_keys):
                pairs.append((left, right))
        elif consider_as_referenced_keys:
            if is_keyed_side(left, right, consider_as_referenced_keys):
                pairs.append((left, right))
            elif is_keyed_side(right, left, consider_as_referenced_keys):
                pairs.append((right, left))
        elif isinstance(left, Column) and isinstance(right, Column):
            if left.references(right):
                pairs.append((right, left))
            elif right.references(left):
                pairs.append((left, right))

    visitors.traverse(expression, {}, {"binary": visit_binary})
    return pairs
1040
1041
1042class ClauseAdapter(visitors.ReplacingExternalTraversal):
1043 """Clones and modifies clauses based on column correspondence.
1044
1045 E.g.::
1046
1047 table1 = Table('sometable', metadata,
1048 Column('col1', Integer),
1049 Column('col2', Integer)
1050 )
1051 table2 = Table('someothertable', metadata,
1052 Column('col1', Integer),
1053 Column('col2', Integer)
1054 )
1055
1056 condition = table1.c.col1 == table2.c.col1
1057
1058 make an alias of table1::
1059
1060 s = table1.alias('foo')
1061
1062 calling ``ClauseAdapter(s).traverse(condition)`` converts
1063 condition to read::
1064
1065 s.c.col1 == table2.c.col1
1066
1067 """
1068
1069 __slots__ = (
1070 "__traverse_options__",
1071 "selectable",
1072 "include_fn",
1073 "exclude_fn",
1074 "equivalents",
1075 "adapt_on_names",
1076 "adapt_from_selectables",
1077 )
1078
1079 def __init__(
1080 self,
1081 selectable: Selectable,
1082 equivalents: Optional[_EquivalentColumnMap] = None,
1083 include_fn: Optional[Callable[[ClauseElement], bool]] = None,
1084 exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
1085 adapt_on_names: bool = False,
1086 anonymize_labels: bool = False,
1087 adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
1088 ):
1089 self.__traverse_options__ = {
1090 "stop_on": [selectable],
1091 "anonymize_labels": anonymize_labels,
1092 }
1093 self.selectable = selectable
1094 self.include_fn = include_fn
1095 self.exclude_fn = exclude_fn
1096 self.equivalents = util.column_dict(equivalents or {})
1097 self.adapt_on_names = adapt_on_names
1098 self.adapt_from_selectables = adapt_from_selectables
1099
    if TYPE_CHECKING:
        # Declarations for the type checker only.  TYPE_CHECKING is false
        # at runtime, so nothing in this branch is defined on the class
        # when the module executes; presumably the traverse() inherited
        # from ReplacingExternalTraversal is the one that runs.

        # passing None gives None back
        @overload
        def traverse(self, obj: Literal[None]) -> None: ...

        # note this specializes the ReplacingExternalTraversal.traverse()
        # method to state
        # that we will return the same kind of ExternalTraversal object as
        # we were given.  This is probably not 100% true, such as it's
        # possible for us to swap out Alias for Table at the top level.
        # Ideally there could be overloads specific to ColumnElement and
        # FromClause but Mypy is not accepting those as compatible with
        # the base ReplacingExternalTraversal
        @overload
        def traverse(self, obj: _ET) -> _ET: ...

        def traverse(
            self, obj: Optional[ExternallyTraversible]
        ) -> Optional[ExternallyTraversible]: ...
1119
1120 def _corresponding_column(
1121 self, col, require_embedded, _seen=util.EMPTY_SET
1122 ):
1123 newcol = self.selectable.corresponding_column(
1124 col, require_embedded=require_embedded
1125 )
1126 if newcol is None and col in self.equivalents and col not in _seen:
1127 for equiv in self.equivalents[col]:
1128 newcol = self._corresponding_column(
1129 equiv,
1130 require_embedded=require_embedded,
1131 _seen=_seen.union([col]),
1132 )
1133 if newcol is not None:
1134 return newcol
1135
1136 if (
1137 self.adapt_on_names
1138 and newcol is None
1139 and isinstance(col, NamedColumn)
1140 ):
1141 newcol = self.selectable.exported_columns.get(col.name)
1142 return newcol
1143
1144 @util.preload_module("sqlalchemy.sql.functions")
1145 def replace(
1146 self, col: _ET, _include_singleton_constants: bool = False
1147 ) -> Optional[_ET]:
1148 functions = util.preloaded.sql_functions
1149
1150 # TODO: cython candidate
1151
1152 if self.include_fn and not self.include_fn(col): # type: ignore
1153 return None
1154 elif self.exclude_fn and self.exclude_fn(col): # type: ignore
1155 return None
1156
1157 if isinstance(col, FromClause) and not isinstance(
1158 col, functions.FunctionElement
1159 ):
1160 if self.selectable.is_derived_from(col):
1161 if self.adapt_from_selectables:
1162 for adp in self.adapt_from_selectables:
1163 if adp.is_derived_from(col):
1164 break
1165 else:
1166 return None
1167 return self.selectable # type: ignore
1168 elif isinstance(col, Alias) and isinstance(
1169 col.element, TableClause
1170 ):
1171 # we are a SELECT statement and not derived from an alias of a
1172 # table (which nonetheless may be a table our SELECT derives
1173 # from), so return the alias to prevent further traversal
1174 # or
1175 # we are an alias of a table and we are not derived from an
1176 # alias of a table (which nonetheless may be the same table
1177 # as ours) so, same thing
1178 return col # type: ignore
1179 else:
1180 # other cases where we are a selectable and the element
1181 # is another join or selectable that contains a table which our
1182 # selectable derives from, that we want to process
1183 return None
1184
1185 elif not isinstance(col, ColumnElement):
1186 return None
1187 elif not _include_singleton_constants and col._is_singleton_constant:
1188 # dont swap out NULL, TRUE, FALSE for a label name
1189 # in a SQL statement that's being rewritten,
1190 # leave them as the constant. This is first noted in #6259,
1191 # however the logic to check this moved here as of #7154 so that
1192 # it is made specific to SQL rewriting and not all column
1193 # correspondence
1194
1195 return None
1196
1197 if "adapt_column" in col._annotations:
1198 col = col._annotations["adapt_column"]
1199
1200 if TYPE_CHECKING:
1201 assert isinstance(col, KeyedColumnElement)
1202
1203 if self.adapt_from_selectables and col not in self.equivalents:
1204 for adp in self.adapt_from_selectables:
1205 if adp.c.corresponding_column(col, False) is not None:
1206 break
1207 else:
1208 return None
1209
1210 if TYPE_CHECKING:
1211 assert isinstance(col, KeyedColumnElement)
1212
1213 return self._corresponding_column( # type: ignore
1214 col, require_embedded=True
1215 )
1216
1217
class _ColumnLookup(Protocol):
    """Structural type for an object supporting ``lookup[key]``.

    This is the declared type of ``ColumnAdapter.columns``.  The
    overloads below describe, for type checkers, how the type of the
    result follows from the type of the key; the bodies are stubs.

    """

    # None in, None out
    @overload
    def __getitem__(self, key: None) -> None: ...

    # a ColumnClause key is typed as producing a ColumnClause
    @overload
    def __getitem__(self, key: ColumnClause[Any]) -> ColumnClause[Any]: ...

    # a ColumnElement key is typed as producing a ColumnElement
    @overload
    def __getitem__(self, key: ColumnElement[Any]) -> ColumnElement[Any]: ...

    # any other _ET key is typed as producing that same type
    @overload
    def __getitem__(self, key: _ET) -> _ET: ...

    # catch-all signature behind the overloads
    def __getitem__(self, key: Any) -> Any: ...
1232
1233
class ColumnAdapter(ClauseAdapter):
    """A ClauseAdapter with a memoizing lookup and additional helpers.

    What ColumnAdapter adds on top of ClauseAdapter:

    * Adapted expressions are kept in a persistent ``.columns``
      collection, so that an expression E which was adapted into E1 gives
      back that same object E1 when it is adapted again.  This matters in
      particular for Label objects that are anonymized, allowing the
      ColumnAdapter to present a consistent "adapted" view of things.

    * Items can be kept out of the persistent collection according to
      include/exclude rules, without relying on hash identity; this is
      needed because "annotated" items all have the same hash identity as
      their parent.

    * A "wrapping" capability, so that replacement of an expression E can
      proceed through a series of adapters.  Unlike the visitor's
      "chaining" feature, the resulting object is passed through all
      replacing functions unconditionally, rather than stopping at the
      first one that returns non-None.

    * An ``adapt_required`` option, used by eager loading to state that a
      result row column which is not translated is not to be trusted.
      This keeps a column from being interpreted as that of the child row
      in a self-referential scenario, see
      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency

    """

    __slots__ = (
        "columns",
        "adapt_required",
        "allow_label_resolve",
        "_wrap",
        "__weakref__",
    )

    # memoizing ``adapter.columns[expr]`` lookup; see _locate_col()
    columns: _ColumnLookup

    def __init__(
        self,
        selectable: Selectable,
        equivalents: Optional[_EquivalentColumnMap] = None,
        adapt_required: bool = False,
        include_fn: Optional[Callable[[ClauseElement], bool]] = None,
        exclude_fn: Optional[Callable[[ClauseElement], bool]] = None,
        adapt_on_names: bool = False,
        allow_label_resolve: bool = True,
        anonymize_labels: bool = False,
        adapt_from_selectables: Optional[AbstractSet[FromClause]] = None,
    ):
        """Construct the adapter.

        All arguments other than ``adapt_required`` and
        ``allow_label_resolve`` are forwarded to ``ClauseAdapter``.

        :param adapt_required: when True, looking up an expression that
         comes back unchanged produces None.
        :param allow_label_resolve: value assigned to the
         ``_allow_label_resolve`` attribute of every expression that was
         actually changed by the lookup.

        """
        super().__init__(
            selectable,
            equivalents,
            include_fn=include_fn,
            exclude_fn=exclude_fn,
            adapt_on_names=adapt_on_names,
            anonymize_labels=anonymize_labels,
            adapt_from_selectables=adapt_from_selectables,
        )

        self.adapt_required = adapt_required
        self.allow_label_resolve = allow_label_resolve
        self._wrap = None

        lookup: Any = util.WeakPopulateDict(self._locate_col)
        if self.include_fn or self.exclude_fn:
            # filtering has to happen in front of the memoizing dict
            lookup = self._IncludeExcludeMapping(self, lookup)
        self.columns = lookup

    class _IncludeExcludeMapping:
        """Lookup front-end applying the adapter's include/exclude rules
        before the memoized collection is consulted."""

        def __init__(self, parent, columns):
            self.parent = parent
            self.columns = columns

        def __getitem__(self, key):
            adapter = self.parent
            include_fn = adapter.include_fn
            exclude_fn = adapter.exclude_fn

            if include_fn and not include_fn(key):
                filtered_out = True
            else:
                filtered_out = bool(exclude_fn and exclude_fn(key))

            if not filtered_out:
                return self.columns[key]

            # a filtered key is handed to the wrapped adapter if there
            # is one, otherwise it is returned untouched
            wrapped = adapter._wrap
            if wrapped:
                return wrapped.columns[key]
            return key

    def wrap(self, adapter):
        """Return a copy of this adapter whose results are additionally
        passed through ``adapter``.

        The copy gets its own, empty ``columns`` collection.

        """
        wrapping = copy.copy(self)
        wrapping._wrap = adapter

        lookup: Any = util.WeakPopulateDict(wrapping._locate_col)
        if wrapping.include_fn or wrapping.exclude_fn:
            lookup = self._IncludeExcludeMapping(wrapping, lookup)
        wrapping.columns = lookup

        return wrapping

    @overload
    def traverse(self, obj: Literal[None]) -> None: ...

    @overload
    def traverse(self, obj: _ET) -> _ET: ...

    def traverse(
        self, obj: Optional[ExternallyTraversible]
    ) -> Optional[ExternallyTraversible]:
        """Adapt ``obj`` by looking it up in the ``columns``
        collection."""
        return self.columns[obj]

    def chain(self, visitor: ExternalTraversal) -> ColumnAdapter:
        """Chain ``visitor``, which must be a ColumnAdapter, onto this
        adapter via the superclass ``chain()``."""
        assert isinstance(visitor, ColumnAdapter)

        chained = super().chain(visitor)
        return chained

    if TYPE_CHECKING:

        @property
        def visitor_iterator(self) -> Iterator[ColumnAdapter]: ...

    # additional names for the same operations
    adapt_clause = traverse
    adapt_list = ClauseAdapter.copy_and_process

    def adapt_check_present(
        self, col: ColumnElement[Any]
    ) -> Optional[ColumnElement[Any]]:
        """Adapt ``col``, returning None if it came back unchanged and
        also has no corresponding column via
        ``_corresponding_column()``."""
        adapted = self.columns[col]

        if adapted is not col:
            return adapted

        if self._corresponding_column(col, True) is None:
            return None

        return adapted

    def _locate_col(
        self, col: ColumnElement[Any]
    ) -> Optional[ColumnElement[Any]]:
        """Compute the adapted form of ``col``.

        Returns the adapted expression; ``col`` itself if nothing was
        replaced; or None if nothing was replaced and ``adapt_required``
        is set.

        """
        # replace() and traverse() both do more than this lookup needs,
        # and an inlined version with less overhead would be preferable.
        # the catch is that the lookup sometimes really does have to
        # adapt the insides of something, e.g. a labeled scalar
        # subquery.  for an Immutable such as a Column object, however,
        # the "clone" / "copy internals" steps would be no-ops, so they
        # are bypassed by calling replace() directly.  the singleton
        # objects null/true/false are to be adapted here as well, hence
        # _include_singleton_constants=True.

        if not col._is_immutable:
            located = ClauseAdapter.traverse(self, col)
        else:
            # first visitor in the chain that offers a replacement
            # wins; with no offers the column stays as it is
            located = col
            for vis in self.visitor_iterator:
                candidate = vis.replace(
                    col, _include_singleton_constants=True
                )
                if candidate is not None:
                    located = candidate
                    break

        if self._wrap:
            rewrapped = self._wrap._locate_col(located)
            if rewrapped is not None:
                located = rewrapped

        if located is col:
            return None if self.adapt_required else located

        # allow_label_resolve is consumed by one case for joined eager
        # loading, as part of its logic to keep its own columns from
        # being affected by .order_by().  Before full typing was applied
        # to the ORM, this attribute would be set on the incoming object
        # (typically a column, though there is a test where it is a
        # non-column object) when no column was found.  That did not
        # seem to have negative effects, but the adjustment should only
        # be made to the new column, which is assumed to be local to an
        # adapted selectable.
        located._allow_label_resolve = self.allow_label_resolve

        return located
1405
1406
def _offset_or_limit_clause(
    element: _LimitOffsetType,
    name: Optional[str] = None,
    type_: Optional[_TypeEngineArgument[int]] = None,
) -> ColumnElement[int]:
    """Coerce the given value into an "offset or limit" clause.

    An incoming integer is turned into an expression, while a value that
    already is an expression is passed through.  The coercion itself is
    carried out by ``coercions.expect()`` using ``roles.LimitOffsetRole``;
    ``name`` and ``type_`` are forwarded to it unchanged.

    """
    coerced: ColumnElement[int] = coercions.expect(
        roles.LimitOffsetRole, element, name=name, type_=type_
    )
    return coerced
1421
1422
1423def _offset_or_limit_clause_asint_if_possible(
1424 clause: _LimitOffsetType,
1425) -> _LimitOffsetType:
1426 """Return the offset or limit clause as a simple integer if possible,
1427 else return the clause.
1428
1429 """
1430 if clause is None:
1431 return None
1432 if hasattr(clause, "_limit_offset_value"):
1433 value = clause._limit_offset_value
1434 return util.asint(value)
1435 else:
1436 return clause
1437
1438
1439def _make_slice(
1440 limit_clause: _LimitOffsetType,
1441 offset_clause: _LimitOffsetType,
1442 start: int,
1443 stop: int,
1444) -> Tuple[Optional[ColumnElement[int]], Optional[ColumnElement[int]]]:
1445 """Compute LIMIT/OFFSET in terms of slice start/end"""
1446
1447 # for calculated limit/offset, try to do the addition of
1448 # values to offset in Python, however if a SQL clause is present
1449 # then the addition has to be on the SQL side.
1450
1451 # TODO: typing is finding a few gaps in here, see if they can be
1452 # closed up
1453
1454 if start is not None and stop is not None:
1455 offset_clause = _offset_or_limit_clause_asint_if_possible(
1456 offset_clause
1457 )
1458 if offset_clause is None:
1459 offset_clause = 0
1460
1461 if start != 0:
1462 offset_clause = offset_clause + start # type: ignore
1463
1464 if offset_clause == 0:
1465 offset_clause = None
1466 else:
1467 assert offset_clause is not None
1468 offset_clause = _offset_or_limit_clause(offset_clause)
1469
1470 limit_clause = _offset_or_limit_clause(stop - start)
1471
1472 elif start is None and stop is not None:
1473 limit_clause = _offset_or_limit_clause(stop)
1474 elif start is not None and stop is None:
1475 offset_clause = _offset_or_limit_clause_asint_if_possible(
1476 offset_clause
1477 )
1478 if offset_clause is None:
1479 offset_clause = 0
1480
1481 if start != 0:
1482 offset_clause = offset_clause + start
1483
1484 if offset_clause == 0:
1485 offset_clause = None
1486 else:
1487 offset_clause = _offset_or_limit_clause(offset_clause)
1488
1489 return limit_clause, offset_clause