1# sql/util.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""High level utilities which build upon other modules here.
9
10"""
11
12from collections import deque
13from itertools import chain
14
15from . import coercions
16from . import operators
17from . import roles
18from . import visitors
19from .annotation import _deep_annotate # noqa
20from .annotation import _deep_deannotate # noqa
21from .annotation import _shallow_annotate # noqa
22from .base import _expand_cloned
23from .base import _from_objects
24from .base import ColumnSet
25from .ddl import sort_tables # noqa
26from .elements import _find_columns # noqa
27from .elements import _label_reference
28from .elements import _textual_label_reference
29from .elements import BindParameter
30from .elements import ColumnClause
31from .elements import ColumnElement
32from .elements import Grouping
33from .elements import Label
34from .elements import Null
35from .elements import UnaryExpression
36from .schema import Column
37from .selectable import Alias
38from .selectable import FromClause
39from .selectable import FromGrouping
40from .selectable import Join
41from .selectable import ScalarSelect
42from .selectable import SelectBase
43from .selectable import TableClause
44from .traversals import HasCacheKey # noqa
45from .. import exc
46from .. import util
47
48
# Public, module-level name for ``Join._join_condition``.  ``public_factory``
# is a project helper; presumably it wraps the target callable so that it is
# documented under the ".sql.util.join_condition" location -- confirm
# against ``util.langhelpers.public_factory``.
join_condition = util.langhelpers.public_factory(
    Join._join_condition, ".sql.util.join_condition"
)
52
53
def find_join_source(clauses, join_to):
    """Given a list of FROM clauses and a selectable, return the indexes
    of the clauses which can be joined against the selectable.

    A clause qualifies when it is derived from any of the FROM objects of
    ``join_to``.  Each qualifying index appears at most once, in the order
    of ``clauses``; an empty list is returned when nothing matches.

    e.g.::

        clause1 = table1.join(table2)
        clause2 = table4.join(table5)

        join_to = table2.join(table3)

        find_join_source([clause1, clause2], join_to) == [0]

    """

    selectables = list(_from_objects(join_to))
    idx = []
    for i, f in enumerate(clauses):
        for s in selectables:
            if f.is_derived_from(s):
                idx.append(i)
                # one match is enough: without this break, a clause derived
                # from several of join_to's FROM objects would be listed
                # several times and look like an ambiguous result
                break
    return idx
78
79
def find_left_clause_that_matches_given(clauses, join_from):
    """Return the indexes of the FROM clauses derived from ``join_from``.

    A first, liberal pass accepts every clause derived from one of
    ``join_from``'s FROM objects.  When more than one clause qualifies, a
    second, conservative pass keeps only the clauses whose surface
    selectables directly overlap those of one of ``join_from``'s FROM
    objects.  The conservative result is returned if it is non-empty;
    otherwise the liberal result is returned.

    """

    targets = list(_from_objects(join_from))

    # liberal pass: plain "is derived from" check.  This matches joins
    # containing a table, an aliased table, or a select statement adapted
    # from a table; with Query, this suits a join made to an adapted
    # entity.
    liberal_idx = [
        pos
        for pos, from_ in enumerate(clauses)
        if any(from_.is_derived_from(target) for target in targets)
    ]

    if len(liberal_idx) <= 1:
        return liberal_idx

    # rare case: the target is represented in more than one FROM clause
    # (embedded or similar).  Look for a more exact match that ignores
    # adaption relationships.
    conservative_idx = []
    for pos in liberal_idx:
        surface = set(surface_selectables(clauses[pos]))
        if any(
            surface.intersection(surface_selectables(target))
            for target in targets
        ):
            conservative_idx.append(pos)

    return conservative_idx or liberal_idx
120
121
def find_left_clause_to_join_from(clauses, join_to, onclause):
    """Given a list of FROM clauses, a selectable, and an optional ON
    clause, return a list of integer indexes into ``clauses`` indicating
    the clauses that can be joined from.

    The presence of an "onclause" indicates that at least one clause can
    definitely be joined from.

    * If there is one clause and the onclause is given, that index is
      returned.
    * If there are several clauses and the onclause is given, the result
      is narrowed to those clauses which, together with one of
      ``join_to``'s FROM objects, supply every column the onclause refers
      to.
    * Without an onclause, a clause qualifies when ``Join._can_join``
      accepts it against one of ``join_to``'s FROM objects.

    When more than one index qualifies, clauses hidden by another clause
    in the list (via ``_hide_froms``) are dropped.  If an onclause was
    given but nothing resolved, every index is returned.

    """
    idx = []
    selectables = set(_from_objects(join_to))

    # if we are given more than one target clause to join
    # from, use the onclause to provide a more specific answer.
    # otherwise, don't try to limit, after all, "ON TRUE" is a valid
    # on clause
    if len(clauses) > 1 and onclause is not None:
        resolve_ambiguity = True
        cols_in_onclause = _find_columns(onclause)
    else:
        resolve_ambiguity = False
        cols_in_onclause = None

    for i, f in enumerate(clauses):
        for s in selectables.difference([f]):
            if resolve_ambiguity:
                # f qualifies if f and s together provide every column
                # referenced by the onclause
                if set(f.c).union(s.c).issuperset(cols_in_onclause):
                    idx.append(i)
                    break
            elif onclause is not None or Join._can_join(f, s):
                idx.append(i)
                break

    if len(idx) > 1:
        # this is the same "hide froms" logic from
        # Selectable._get_display_froms
        toremove = set(
            chain(*[_expand_cloned(f._hide_froms) for f in clauses])
        )
        idx = [i for i in idx if clauses[i] not in toremove]

    # onclause was given and none of them resolved, so assume
    # all indexes can match
    if not idx and onclause is not None:
        # materialize so that every path returns a list, as documented
        return list(range(len(clauses)))
    else:
        return idx
172
173
def visit_binary_product(fn, expr):
    """Produce a traversal of the given expression, delivering
    column comparisons to the given function.

    The function is of the form::

        def my_fn(binary, left, right)

    For each binary expression located which has a
    comparison operator, the product of "left" and
    "right" will be delivered to that function,
    in terms of that binary.

    Hence an expression like::

        and_(
            (a + b) == q + func.sum(e + f),
            j == r
        )

    would have the traversal::

        a <eq> q
        a <eq> e
        a <eq> f
        b <eq> q
        b <eq> e
        b <eq> f
        j <eq> r

    That is, every combination of "left" and
    "right" that doesn't further contain
    a binary comparison is passed as pairs.

    Scalar subqueries are treated as single column elements and are not
    descended into.

    """
    stack = []

    def visit(element):
        if isinstance(element, ScalarSelect):
            # we don't want to dig into correlated subqueries,
            # those are just column elements by themselves
            yield element
        elif element.__visit_name__ == "binary" and operators.is_comparison(
            element.operator
        ):
            stack.insert(0, element)
            # Walking each operand through visit() also processes any
            # comparison nested inside it.  Note the right-hand side is
            # re-walked once per value yielded from the left, and is not
            # walked at all if the left yields nothing.
            for l in visit(element.left):
                for r in visit(element.right):
                    fn(stack[0], l, r)
            stack.pop(0)
            # No separate pass over element.get_children() is made here.
            # visit() is a generator function, so calling it without
            # iterating the result would never execute its body.
        else:
            if isinstance(element, ColumnClause):
                yield element
            for elem in element.get_children():
                for e in visit(elem):
                    yield e

    list(visit(expr))
    visit = None  # remove gc cycles
235
236
def find_tables(
    clause,
    check_columns=False,
    include_aliases=False,
    include_joins=False,
    include_selects=False,
    include_crud=False,
):
    """locate Table objects within the given expression.

    ``clause`` is walked with ``visitors.traverse``, and every "table"
    element encountered is collected.  The ``include_*`` flags also
    collect SELECT / join / alias-like elements, or the ``.table`` of
    INSERT / UPDATE / DELETE statements.  ``check_columns`` collects the
    ``.table`` of each "column" element.  Results are in visit order; no
    de-duplication is performed.

    """
    found = []
    collect = found.append

    def collect_crud_table(stmt):
        collect(stmt.table)

    def collect_column_table(column):
        collect(column.table)

    # (flag, visit names, handler) -- a handler is registered for every
    # visit name of an enabled group
    groups = (
        (include_selects, ("select", "compound_select"), collect),
        (include_joins, ("join",), collect),
        (
            include_aliases,
            ("alias", "subquery", "tablesample", "lateral"),
            collect,
        ),
        (include_crud, ("insert", "update", "delete"), collect_crud_table),
        (check_columns, ("column",), collect_column_table),
    )

    dispatch = {}
    for enabled, visit_names, handler in groups:
        if enabled:
            for visit_name in visit_names:
                dispatch[visit_name] = handler
    dispatch["table"] = collect

    visitors.traverse(clause, {}, dispatch)
    return found
277
278
def unwrap_order_by(clause):
    """Break up an 'order by' expression into individual column-expressions,
    without DESC/ASC/NULLS FIRST/NULLS LAST

    The expression is walked breadth-first:

    * unary expressions carrying an ordering modifier, and any element
      that is not a ``ColumnElement``, are looked inside via
      ``get_children()``;
    * a ``Label`` (unless it labels a scalar subquery) is replaced by its
      element, also stripping one ``Grouping`` directly inside it, and the
      result is re-queued;
    * a ``_label_reference`` is replaced by its element and re-queued;
    * a ``_textual_label_reference`` is dropped.

    The remaining column expressions are returned in the order first
    found, skipping ones already collected.

    """

    cols = util.column_set()  # what has already been collected
    result = []
    stack = deque([clause])  # used as a FIFO queue (popleft)

    # examples
    # column -> ASC/DESC == column
    # column -> ASC/DESC -> label == column
    # column -> label -> ASC/DESC -> label == column
    # scalar_select -> label -> ASC/DESC == scalar_select -> label

    while stack:
        t = stack.popleft()
        if isinstance(t, ColumnElement) and (
            not isinstance(t, UnaryExpression)
            or not operators.is_ordering_modifier(t.modifier)
        ):
            if isinstance(t, Label) and not isinstance(
                t.element, ScalarSelect
            ):
                # strip the label; a labeled scalar subquery is kept whole
                t = t.element

                if isinstance(t, Grouping):
                    t = t.element

                # re-queue the unwrapped element for further processing
                stack.append(t)
                continue
            elif isinstance(t, _label_reference):
                t = t.element

                stack.append(t)
                continue
            if isinstance(t, (_textual_label_reference)):
                # textual label references are not column expressions
                continue
            if t not in cols:
                cols.add(t)
                result.append(t)

        else:
            # ordering modifiers (ASC/DESC/NULLS ...) and non-column
            # elements: look inside
            for c in t.get_children():
                stack.append(c)
    return result
324
325
def unwrap_label_reference(element):
    """Return ``element`` as rebuilt by ``visitors.replacement_traverse``,
    with each ``_label_reference`` / ``_textual_label_reference`` replaced
    by the element it wraps.

    """
    wrappers = (_label_reference, _textual_label_reference)

    def _unwrap(elem):
        # None tells the traversal to leave this element to its default
        # handling
        return elem.element if isinstance(elem, wrappers) else None

    return visitors.replacement_traverse(element, {}, _unwrap)
332
333
def expand_column_list_from_order_by(collist, order_by):
    """Given the columns clause and ORDER BY of a selectable,
    return a list of column expressions that can be added to the collist
    corresponding to the ORDER BY, without repeating those already
    in the collist.

    Columns of ``collist`` that have an ORDER BY label element are
    compared by their ``.element``; ORDER BY terms are broken up with
    ``unwrap_order_by``.

    """
    present = {
        col.element if col._order_by_label_element is not None else col
        for col in collist
    }
    candidates = chain.from_iterable(unwrap_order_by(o) for o in order_by)
    return [col for col in candidates if col not in present]
351
352
def clause_is_present(clause, search):
    """Given a target clause and a second to search within, return True
    if the target is plainly present in the search without any
    subqueries or aliases involved.

    Basically descends through Joins.

    """
    # "==" rather than "is", so that Annotated's comparison is honored
    return any(clause == elem for elem in surface_selectables(search))
367
368
def tables_from_leftmost(clause):
    """Yield the leaf FROM elements of ``clause``, leftmost first.

    ``Join`` elements are descended into (left side before right side)
    and ``FromGrouping`` wrappers are unwrapped; anything else is yielded
    as-is.

    """
    pending = [clause]
    while pending:
        current = pending.pop()
        if isinstance(current, Join):
            # push the right side first so the left side is fully
            # processed before it
            pending.append(current.right)
            pending.append(current.left)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
        else:
            yield current
380
381
def surface_selectables(clause):
    """Yield ``clause`` and every element reachable from it through
    ``Join`` sides and ``FromGrouping`` wrappers.

    Traversal is depth-first using a LIFO stack, so for a ``Join`` the
    right side is yielded before the left side.

    """
    pending = [clause]
    while pending:
        node = pending.pop()
        yield node
        if isinstance(node, FromGrouping):
            pending.append(node.element)
        elif isinstance(node, Join):
            pending.append(node.left)
            pending.append(node.right)
391
392
def surface_selectables_only(clause):
    """Yield the FROM-level elements found at the surface of ``clause``.

    * ``TableClause`` and ``Alias`` elements are yielded.
    * ``Join`` elements are descended into (both sides).
    * ``FromGrouping`` wrappers are unwrapped.
    * A ``ColumnClause`` is replaced by its ``.table`` when it has one;
      otherwise the column itself is yielded.
    * Any other element is yielded, except ``None``, which is skipped.

    """
    stack = [clause]
    while stack:
        elem = stack.pop()
        if isinstance(elem, (TableClause, Alias)):
            yield elem
        # "elif" (not a fresh "if"): tables and aliases are terminal.
        # Letting them fall through to the final branch yielded them a
        # second time.
        elif isinstance(elem, Join):
            stack.extend((elem.left, elem.right))
        elif isinstance(elem, FromGrouping):
            stack.append(elem.element)
        elif isinstance(elem, ColumnClause):
            if elem.table is not None:
                stack.append(elem.table)
            else:
                yield elem
        elif elem is not None:
            yield elem
410
411
def extract_first_column_annotation(column, annotation_name):
    """Return the value of annotation ``annotation_name`` from the first
    element that carries it, searching ``column`` and its children
    breadth-first.

    ``FromGrouping`` and ``SelectBase`` children are not descended into.
    Returns None when no element carries the annotation.

    """
    skip_types = (FromGrouping, SelectBase)
    queue = deque([column])
    while queue:
        current = queue.popleft()
        annotations = current._annotations
        if annotation_name in annotations:
            return annotations[annotation_name]
        queue.extend(
            child
            for child in current.get_children()
            if not isinstance(child, skip_types)
        )
    return None
425
426
def selectables_overlap(left, right):
    """Return True if left/right have some overlapping selectable

    Both sides are expanded with ``surface_selectables`` before being
    compared.

    """
    left_surface = set(surface_selectables(left))
    return not left_surface.isdisjoint(surface_selectables(right))
433
434
def bind_values(clause):
    """Return an ordered list of "bound" values in the given clause.

    Each "bindparam" element met during ``visitors.traverse`` contributes
    its ``effective_value``, in traversal order.

    E.g.::

        >>> expr = and_(
        ...    table.c.foo==5, table.c.foo==7
        ... )
        >>> bind_values(expr)
        [5, 7]
    """
    collected = []
    visitors.traverse(
        clause,
        {},
        {"bindparam": lambda bind: collected.append(bind.effective_value)},
    )
    return collected
454
455
def _quote_ddl_expr(element):
    """Render ``element`` as a literal for DDL text.

    Strings are wrapped in single quotes, with embedded single quotes
    doubled; any other value is rendered with ``repr()``.

    """
    if not isinstance(element, util.string_types):
        return repr(element)
    escaped = element.replace("'", "''")
    return "'%s'" % escaped
462
463
464class _repr_base(object):
465 _LIST = 0
466 _TUPLE = 1
467 _DICT = 2
468
469 __slots__ = ("max_chars",)
470
471 def trunc(self, value):
472 rep = repr(value)
473 lenrep = len(rep)
474 if lenrep > self.max_chars:
475 segment_length = self.max_chars // 2
476 rep = (
477 rep[0:segment_length]
478 + (
479 " ... (%d characters truncated) ... "
480 % (lenrep - self.max_chars)
481 )
482 + rep[-segment_length:]
483 )
484 return rep
485
486
def _repr_single_value(value):
    """Return ``repr(value)``, truncated when longer than 300 characters."""
    truncator = _repr_base()
    truncator.max_chars = 300
    shown = truncator.trunc(value)
    return shown
491
492
class _repr_row(_repr_base):
    """Provide a string view of a row.

    Each value is rendered with ``trunc()``, and the whole is formatted
    like a Python tuple, including the trailing comma for a one-element
    row.

    """

    __slots__ = ("row",)

    def __init__(self, row, max_chars=300):
        self.row = row
        self.max_chars = max_chars

    def __repr__(self):
        rendered = [self.trunc(item) for item in self.row]
        trailing = "," if len(self.row) == 1 else ""
        return "(%s%s)" % (", ".join(rendered), trailing)
508
509
class _repr_params(_repr_base):
    """Provide a string view of bound parameters.

    Truncates display to a given number of 'multi' parameter sets,
    as well as long values to a given number of characters.

    ``ismulti`` selects the mode:

    * ``None`` renders ``params`` as a single truncated ``repr()``;
    * a true value treats ``params`` as a list/tuple of parameter sets,
      showing at most ``batches`` of them;
    * any other false value treats ``params`` as one parameter set.

    """

    __slots__ = "params", "batches", "ismulti"

    def __init__(self, params, batches, max_chars=300, ismulti=None):
        self.params = params
        self.ismulti = ismulti
        self.batches = batches
        self.max_chars = max_chars

    def __repr__(self):
        if self.ismulti is None:
            return self.trunc(self.params)

        if isinstance(self.params, list):
            typ = self._LIST
        elif isinstance(self.params, tuple):
            typ = self._TUPLE
        elif isinstance(self.params, dict):
            typ = self._DICT
        else:
            return self.trunc(self.params)

        if self.ismulti and len(self.params) > self.batches:
            # Show the first (batches - 2) sets and the last 2 sets, with
            # the message spliced in where the closing/opening brackets
            # were cut off.
            # NOTE(review): with batches < 2 the head slice bound is
            # negative, so head and tail overlap.
            msg = " ... displaying %i of %i total bound parameter sets ... "
            return " ".join(
                (
                    self._repr_multi(self.params[: self.batches - 2], typ)[
                        0:-1
                    ],
                    msg % (self.batches, len(self.params)),
                    self._repr_multi(self.params[-2:], typ)[1:],
                )
            )
        elif self.ismulti:
            return self._repr_multi(self.params, typ)
        else:
            return self._repr_params(self.params, typ)

    def _repr_multi(self, multi_params, typ):
        """Render a sequence of parameter sets, bracketed according to
        ``typ``; the kind of each set is taken from the first one."""
        if multi_params:
            first = multi_params[0]
            if isinstance(first, list):
                elem_type = self._LIST
            elif isinstance(first, tuple):
                elem_type = self._TUPLE
            elif isinstance(first, dict):
                elem_type = self._DICT
            else:
                # Raise explicitly.  An "assert" statement is stripped
                # under "python -O", which would leave elem_type unbound.
                raise AssertionError(
                    "Unknown parameter type %s" % (type(first))
                )

            elements = ", ".join(
                self._repr_params(params, elem_type) for params in multi_params
            )
        else:
            elements = ""

        if typ == self._LIST:
            return "[%s]" % elements
        else:
            return "(%s)" % elements

    def _repr_params(self, params, typ):
        """Render one parameter set as a dict, tuple or list literal, with
        each value truncated."""
        trunc = self.trunc
        if typ == self._DICT:
            return "{%s}" % (
                ", ".join(
                    "%r: %s" % (key, trunc(value))
                    for key, value in params.items()
                )
            )
        elif typ == self._TUPLE:
            return "(%s%s)" % (
                ", ".join(trunc(value) for value in params),
                "," if len(params) == 1 else "",
            )
        else:
            return "[%s]" % (", ".join(trunc(value) for value in params))
596
597
def adapt_criterion_to_null(crit, nulls):
    """given criterion containing bind params, convert selected elements
    to IS NULL.

    ``crit`` is copied with ``visitors.cloned_traverse``.  In each binary
    expression whose left or right operand is a ``BindParameter`` whose
    ``_identifying_key`` is in ``nulls``, that operand is replaced by
    ``NULL`` and the operator becomes IS (negation IS NOT).  If the
    matching operand was on the left, the other operand is moved to the
    left so that NULL is always on the right.

    """

    def _is_null_bind(operand):
        return (
            isinstance(operand, BindParameter)
            and operand._identifying_key in nulls
        )

    def visit_binary(binary):
        if _is_null_bind(binary.left):
            # reverse order if the NULL is on the left side
            binary.left = binary.right
        elif not _is_null_bind(binary.right):
            return
        binary.right = Null()
        binary.operator = operators.is_
        binary.negate = operators.is_not

    return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
623
624
def splice_joins(left, right, stop_on=None):
    """Return ``right`` adapted against ``left``.

    If ``left`` is None, ``right`` is returned unchanged.  Otherwise the
    chain of ``Join`` objects reachable from ``right`` through ``.left``
    is walked, stopping at ``stop_on`` or at the first non-join.

    * Each join in the chain is cloned, and its ``onclause`` is adapted
      with a ``ClauseAdapter(left)``.
    * The terminal element is adapted with the same adapter.
    * Each adapted element is attached as ``.left`` of the clone before
      it.

    The outermost resulting element is returned.

    """
    if left is None:
        return right

    adapter = ClauseAdapter(left)
    spliced = None
    parent = None
    current = right
    while True:
        is_join = isinstance(current, Join) and current is not stop_on
        if is_join:
            node = current._clone()
            node.onclause = adapter.traverse(node.onclause)
        else:
            node = adapter.traverse(current)

        if parent is not None:
            parent.left = node
        if spliced is None:
            spliced = node

        if not is_join:
            return spliced
        parent, current = node, node.left
647
648
def reduce_columns(columns, *clauses, **kw):
    r"""given a list of columns, return a 'reduced' set based on natural
    equivalents.

    the set is reduced to the smallest list of columns which have no natural
    equivalent present in the list. A "natural equivalent" means that two
    columns will ultimately represent the same value because they are related
    by a foreign key.

    \*clauses is an optional list of join clauses which will be traversed
    to further identify columns that are "equivalent".

    \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
    whose tables are not yet configured, or columns that aren't yet present.

    \**kw may also specify 'only_synonyms'; when true, two columns are only
    considered equivalent if they also have the same ``.name``.  Other
    keyword arguments are ignored.

    This function is primarily used to determine the most minimal "primary
    key" from a selectable, by reducing the set of primary key columns present
    in the selectable to just those that are not repeated.

    Returns a ``ColumnSet`` of the columns that were not omitted.

    """
    ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
    only_synonyms = kw.pop("only_synonyms", False)

    # ordered set; the reversed() scan further down depends on this order
    columns = util.ordered_column_set(columns)

    # columns found to be redundant
    omit = util.column_set()
    for col in columns:
        # foreign keys of col and of everything in col's proxy_set
        for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
            for c in columns:
                if c is col:
                    continue
                try:
                    fk_col = fk.column
                except exc.NoReferencedColumnError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                except exc.NoReferencedTableError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                if fk_col.shares_lineage(c) and (
                    not only_synonyms or c.name == col.name
                ):
                    # col refers (via fk) to another listed column: col is
                    # the redundant one.  The break only leaves the
                    # "for c" loop; col's other foreign keys are still
                    # checked, which is harmless since omit is a set.
                    omit.add(col)
                    break

    if clauses:

        def visit_binary(binary):
            # an "a == b" criterion in the given clauses makes a and b
            # equivalent
            if binary.operator == operators.eq:
                # proxies of the columns not (yet) omitted; rebuilt for
                # every binary because omit grows during the traversal
                cols = util.column_set(
                    chain(*[c.proxy_set for c in columns.difference(omit)])
                )
                if binary.left in cols and binary.right in cols:
                    # omit the last-listed column sharing lineage with the
                    # right-hand side
                    for c in reversed(columns):
                        if c.shares_lineage(binary.right) and (
                            not only_synonyms or c.name == binary.left.name
                        ):
                            omit.add(c)
                            break

        for clause in clauses:
            if clause is not None:
                visitors.traverse(clause, {}, {"binary": visit_binary})

    return ColumnSet(columns.difference(omit))
722
723
def criterion_as_pairs(
    expression,
    consider_as_foreign_keys=None,
    consider_as_referenced_keys=None,
    any_operator=False,
):
    """traverse an expression and locate binary criterion pairs.

    Only binary expressions whose two sides are both ``ColumnElement``
    objects are considered, and only ``==`` unless ``any_operator`` is
    true.  Each such expression may contribute one
    ``(referenced, referencing)`` tuple:

    * with ``consider_as_foreign_keys``, the side in that collection is
      the referencing one;
    * with ``consider_as_referenced_keys``, the side in that collection is
      the referenced one;
    * otherwise, two ``Column`` objects are paired according to
      ``Column.references()``.

    Raises ``ArgumentError`` if both collections are given.

    """

    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError(
            "Can only specify one of "
            "'consider_as_foreign_keys' or "
            "'consider_as_referenced_keys'"
        )

    pairs = []

    def _designated(col, other, keys):
        # col belongs to keys, and other is either the same column (per
        # compare()) or not itself in keys
        return col in keys and (other.compare(col) or other not in keys)

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return
        lhs, rhs = binary.left, binary.right
        if not (
            isinstance(lhs, ColumnElement) and isinstance(rhs, ColumnElement)
        ):
            return

        if consider_as_foreign_keys:
            if _designated(lhs, rhs, consider_as_foreign_keys):
                pairs.append((rhs, lhs))
            elif _designated(rhs, lhs, consider_as_foreign_keys):
                pairs.append((lhs, rhs))
        elif consider_as_referenced_keys:
            if _designated(lhs, rhs, consider_as_referenced_keys):
                pairs.append((lhs, rhs))
            elif _designated(rhs, lhs, consider_as_referenced_keys):
                pairs.append((rhs, lhs))
        elif isinstance(lhs, Column) and isinstance(rhs, Column):
            if lhs.references(rhs):
                pairs.append((rhs, lhs))
            elif rhs.references(lhs):
                pairs.append((lhs, rhs))

    visitors.traverse(expression, {}, {"binary": visit_binary})
    return pairs
785
786
class ClauseAdapter(visitors.ReplacingExternalTraversal):
    """Clones and modifies clauses based on column correspondence.

    E.g.::

      table1 = Table('sometable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )
      table2 = Table('someothertable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )

      condition = table1.c.col1 == table2.c.col1

    make an alias of table1::

      s = table1.alias('foo')

    calling ``ClauseAdapter(s).traverse(condition)`` converts
    condition to read::

      s.c.col1 == table2.c.col1

    """

    def __init__(
        self,
        selectable,
        equivalents=None,
        include_fn=None,
        exclude_fn=None,
        adapt_on_names=False,
        anonymize_labels=False,
        adapt_from_selectables=None,
    ):
        # Options consumed by the traversal machinery of the base class.
        # "stop_on" lists the target selectable, presumably so that it is
        # not itself traversed/cloned -- confirm in visitors.
        self.__traverse_options__ = {
            "stop_on": [selectable],
            "anonymize_labels": anonymize_labels,
        }
        # the selectable that elements are adapted *to*
        self.selectable = selectable
        # optional predicates restricting which columns replace() adapts
        self.include_fn = include_fn
        self.exclude_fn = exclude_fn
        # column -> iterable of equivalent columns, tried when a column has
        # no direct correspondent in self.selectable
        self.equivalents = util.column_dict(equivalents or {})
        # fall back to a lookup by column name in exported_columns
        self.adapt_on_names = adapt_on_names
        # when set, only FROM objects / columns corresponding to one of
        # these selectables are adapted (see replace())
        self.adapt_from_selectables = adapt_from_selectables

    def _corresponding_column(
        self, col, require_embedded, _seen=util.EMPTY_SET
    ):
        """Return the column of ``self.selectable`` corresponding to
        ``col``, or None.

        Lookup order:

        1. A direct correspondent of ``col`` in ``self.selectable``.
        2. Each equivalent of ``col`` (from ``self.equivalents``), tried
           recursively; ``_seen`` holds the columns already visited on
           this path, so cyclic equivalences terminate.
        3. If ``adapt_on_names`` is set, the column with the same
           ``.name`` in ``exported_columns``.

        """

        newcol = self.selectable.corresponding_column(
            col, require_embedded=require_embedded
        )
        if newcol is None and col in self.equivalents and col not in _seen:
            for equiv in self.equivalents[col]:
                newcol = self._corresponding_column(
                    equiv,
                    require_embedded=require_embedded,
                    _seen=_seen.union([col]),
                )
                if newcol is not None:
                    return newcol
        if self.adapt_on_names and newcol is None:
            newcol = self.selectable.exported_columns.get(col.name)
        return newcol

    @util.preload_module("sqlalchemy.sql.functions")
    def replace(self, col, _include_singleton_constants=False):
        """Return the replacement for ``col``, or None for no replacement.

        FROM objects (other than function elements) are mapped to
        ``self.selectable`` when it is derived from them.  Column elements
        are mapped to their corresponding column in ``self.selectable``,
        subject to ``adapt_from_selectables``, ``include_fn`` and
        ``exclude_fn``.  Singleton constants (NULL, TRUE, FALSE) are left
        alone unless ``_include_singleton_constants`` is set.

        """
        functions = util.preloaded.sql_functions

        if isinstance(col, FromClause) and not isinstance(
            col, functions.FunctionElement
        ):

            if self.selectable.is_derived_from(col):
                if self.adapt_from_selectables:
                    # only adapt if one of the allowed selectables is
                    # derived from col as well
                    for adp in self.adapt_from_selectables:
                        if adp.is_derived_from(col):
                            break
                    else:
                        return None
                return self.selectable
            elif isinstance(col, Alias) and isinstance(
                col.element, TableClause
            ):
                # we are a SELECT statement and not derived from an alias of a
                # table (which nonetheless may be a table our SELECT derives
                # from), so return the alias to prevent further traversal
                # or
                # we are an alias of a table and we are not derived from an
                # alias of a table (which nonetheless may be the same table
                # as ours) so, same thing
                return col
            else:
                # other cases where we are a selectable and the element
                # is another join or selectable that contains a table which our
                # selectable derives from, that we want to process
                return None

        elif not isinstance(col, ColumnElement):
            return None
        elif not _include_singleton_constants and col._is_singleton_constant:
            # don't swap out NULL, TRUE, FALSE for a label name
            # in a SQL statement that's being rewritten,
            # leave them as the constant. This is first noted in #6259,
            # however the logic to check this moved here as of #7154 so that
            # it is made specific to SQL rewriting and not all column
            # correspondence
            return None

        # an "adapt_column" annotation redirects adaptation to another
        # column
        if "adapt_column" in col._annotations:
            col = col._annotations["adapt_column"]

        if self.adapt_from_selectables and col not in self.equivalents:
            # columns not listed in equivalents must correspond to one of
            # the allowed selectables
            for adp in self.adapt_from_selectables:
                if adp.c.corresponding_column(col, False) is not None:
                    break
            else:
                return None

        if self.include_fn and not self.include_fn(col):
            return None
        elif self.exclude_fn and self.exclude_fn(col):
            return None
        else:
            return self._corresponding_column(col, True)
915
916
class ColumnAdapter(ClauseAdapter):
    """Extends ClauseAdapter with extra utility functions.

    Key aspects of ColumnAdapter include:

    * Expressions that are adapted are stored in a persistent
      .columns collection; so that an expression E adapted into
      an expression E1, will return the same object E1 when adapted
      a second time.   This is important in particular for things like
      Label objects that are anonymized, so that the ColumnAdapter can
      be used to present a consistent "adapted" view of things.

    * Exclusion of items from the persistent collection based on
      include/exclude rules, but also independent of hash identity.
      This because "annotated" items all have the same hash identity as their
      parent.

    * "wrapping" capability is added, so that the replacement of an expression
      E can proceed through a series of adapters.  This differs from the
      visitor's "chaining" feature in that the resulting object is passed
      through all replacing functions unconditionally, rather than stopping
      at the first one that returns non-None.

    * An adapt_required option, used by eager loading to indicate that
      we don't trust a result row column that is not translated.
      This is to prevent a column from being interpreted as that
      of the child row in a self-referential scenario, see
      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency

    """

    def __init__(
        self,
        selectable,
        equivalents=None,
        adapt_required=False,
        include_fn=None,
        exclude_fn=None,
        adapt_on_names=False,
        allow_label_resolve=True,
        anonymize_labels=False,
        adapt_from_selectables=None,
    ):
        ClauseAdapter.__init__(
            self,
            selectable,
            equivalents,
            include_fn=include_fn,
            exclude_fn=exclude_fn,
            adapt_on_names=adapt_on_names,
            anonymize_labels=anonymize_labels,
            adapt_from_selectables=adapt_from_selectables,
        )

        # persistent element -> adapted element cache, filled on demand
        # by _locate_col
        self.columns = util.WeakPopulateDict(self._locate_col)
        if self.include_fn or self.exclude_fn:
            self.columns = self._IncludeExcludeMapping(self, self.columns)
        self.adapt_required = adapt_required
        self.allow_label_resolve = allow_label_resolve
        self._wrap = None

    class _IncludeExcludeMapping(object):
        """Front for the ``columns`` cache that applies include/exclude.

        Keys rejected by the parent's ``include_fn`` / ``exclude_fn``
        bypass the cache.  They are looked up in the wrapped adapter's
        ``columns`` if there is one, otherwise returned unchanged.

        """

        def __init__(self, parent, columns):
            self.parent = parent
            self.columns = columns

        def __getitem__(self, key):
            if (
                self.parent.include_fn and not self.parent.include_fn(key)
            ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)):
                if self.parent._wrap:
                    return self.parent._wrap.columns[key]
                else:
                    return key
            return self.columns[key]

    def wrap(self, adapter):
        """Return a copy of this adapter whose results are further passed
        through ``adapter``, with a fresh ``columns`` cache."""
        ac = self.__class__.__new__(self.__class__)
        ac.__dict__.update(self.__dict__)
        ac._wrap = adapter
        ac.columns = util.WeakPopulateDict(ac._locate_col)
        if ac.include_fn or ac.exclude_fn:
            ac.columns = self._IncludeExcludeMapping(ac, ac.columns)

        return ac

    def traverse(self, obj):
        """Return the adapted form of ``obj``, via the ``columns`` cache."""
        return self.columns[obj]

    adapt_clause = traverse
    adapt_list = ClauseAdapter.copy_and_process

    def adapt_check_present(self, col):
        """Like ``traverse``, but return None when ``col`` came back
        unchanged and has no corresponding column in the selectable."""
        newcol = self.columns[col]

        if newcol is col and self._corresponding_column(col, True) is None:
            return None

        return newcol

    def _locate_col(self, col):
        """Compute the adapted form of ``col`` for the ``columns`` cache.

        Returns None when ``adapt_required`` is set and ``col`` was left
        unchanged.

        """
        # both replace and traverse() are overly complicated for what
        # we are doing here and we would do better to have an inlined
        # version that doesn't build up as much overhead.  the issue is that
        # sometimes the lookup does in fact have to adapt the insides of
        # say a labeled scalar subquery.   However, if the object is an
        # Immutable, i.e. Column objects, we can skip the "clone" /
        # "copy internals" part since those will be no-ops in any case.
        # additionally we want to catch singleton objects null/true/false
        # and make sure they are adapted as well here.

        if col._is_immutable:
            for vis in self.visitor_iterator:
                c = vis.replace(col, _include_singleton_constants=True)
                if c is not None:
                    break
            else:
                c = col
        else:
            c = ClauseAdapter.traverse(self, col)

        if self._wrap:
            c2 = self._wrap._locate_col(c)
            if c2 is not None:
                c = c2

        if self.adapt_required and c is col:
            return None

        # NOTE(review): when nothing was adapted, c is col itself, so this
        # also sets the flag on the caller's original element.
        c._allow_label_resolve = self.allow_label_resolve

        return c

    def __getstate__(self):
        # the columns cache is rebuilt on unpickle instead of pickled
        d = self.__dict__.copy()
        del d["columns"]
        return d

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.columns = util.WeakPopulateDict(self._locate_col)
        # Re-apply the include/exclude front exactly as __init__ and wrap()
        # do.  Without it, an unpickled adapter would ignore include_fn /
        # exclude_fn for lookups through .columns.
        if self.include_fn or self.exclude_fn:
            self.columns = self._IncludeExcludeMapping(self, self.columns)
1058
1059
def _offset_or_limit_clause(element, name=None, type_=None):
    """Coerce ``element`` into a LIMIT / OFFSET clause.

    Incoming integers are converted to an expression, and an expression
    that is already given is passed through.  This delegates to
    ``coercions.expect`` under ``roles.LimitOffsetRole``, forwarding
    ``name`` and ``type_``.

    """
    role = roles.LimitOffsetRole
    return coercions.expect(role, element, name=name, type_=type_)
1070
1071
def _offset_or_limit_clause_asint_if_possible(clause):
    """Return the offset or limit clause as a simple integer if possible,
    else return the clause.

    A clause exposing ``_limit_offset_value`` is reduced to that value
    via ``util.asint``; None and any other clause are returned as given.

    """
    if clause is not None and hasattr(clause, "_limit_offset_value"):
        return util.asint(clause._limit_offset_value)
    return clause
1084
1085
def _make_slice(limit_clause, offset_clause, start, stop):
    """Compute LIMIT/OFFSET in terms of slice start/end

    Given the current ``limit_clause`` / ``offset_clause`` and the
    ``start`` / ``stop`` of a slice (either may be None), return a new
    ``(limit_clause, offset_clause)`` tuple:

    * ``start`` is added to the current offset;
    * the limit is replaced by ``stop - start``, or by ``stop`` alone
      when there is no ``start``;
    * when neither ``start`` nor ``stop`` is given, both clauses are
      returned unchanged.

    """
    if start is not None:
        offset_clause = _offset_plus_start(offset_clause, start)
        if stop is not None:
            limit_clause = _offset_or_limit_clause(stop - start)
    elif stop is not None:
        limit_clause = _offset_or_limit_clause(stop)

    return limit_clause, offset_clause


def _offset_plus_start(offset_clause, start):
    """Return ``offset_clause`` advanced by ``start`` as an offset clause,
    or None when the resulting offset is 0."""
    # for calculated limit/offset, try to do the addition of
    # values to offset in Python, however if a SQL clause is present
    # then the addition has to be on the SQL side.
    offset_clause = _offset_or_limit_clause_asint_if_possible(offset_clause)
    if offset_clause is None:
        offset_clause = 0

    if start != 0:
        offset_clause = offset_clause + start

    if offset_clause == 0:
        return None
    return _offset_or_limit_clause(offset_clause)