1# sql/util.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""High level utilities which build upon other modules here.
9
10"""
11
12from collections import deque
13from itertools import chain
14
15from . import operators
16from . import visitors
17from .annotation import _deep_annotate # noqa
18from .annotation import _deep_deannotate # noqa
19from .annotation import _shallow_annotate # noqa
20from .base import _from_objects
21from .base import ColumnSet
22from .ddl import sort_tables # noqa
23from .elements import _expand_cloned
24from .elements import _find_columns # noqa
25from .elements import _label_reference
26from .elements import _textual_label_reference
27from .elements import BindParameter
28from .elements import ColumnClause
29from .elements import ColumnElement
30from .elements import Grouping
31from .elements import Label
32from .elements import Null
33from .elements import UnaryExpression
34from .schema import Column
35from .selectable import Alias
36from .selectable import FromClause
37from .selectable import FromGrouping
38from .selectable import Join
39from .selectable import ScalarSelect
40from .selectable import SelectBase
41from .selectable import TableClause
42from .. import exc
43from .. import util
44
45
# Module-level public entry point for Join._join_condition.  The second
# argument is the location string ".sql.util.join_condition" handed to
# util.langhelpers.public_factory (presumably used to name/document the
# public wrapper -- confirm in util.langhelpers).
join_condition = util.langhelpers.public_factory(
    Join._join_condition, ".sql.util.join_condition"
)
49
50
def find_join_source(clauses, join_to):
    """Given a list of FROM clauses and a selectable, return the indexes
    of the clauses which can be joined against the selectable.

    A clause is a candidate when it is derived from any of the FROM
    objects of ``join_to``.  Each matching index appears at most once,
    in the order of ``clauses``; an empty list is returned when no
    clause matches.

    e.g.::

        clause1 = table1.join(table2)
        clause2 = table4.join(table5)

        join_to = table2.join(table3)

        find_join_source([clause1, clause2], join_to) == [0]

    """

    selectables = list(_from_objects(join_to))
    idx = []
    for i, f in enumerate(clauses):
        for s in selectables:
            if f.is_derived_from(s):
                # one match is enough: without the break, a clause derived
                # from several of join_to's FROM objects would be listed
                # more than once, so a len() check on the result could
                # mistake a single clause for an ambiguous match
                idx.append(i)
                break
    return idx
75
76
def find_left_clause_that_matches_given(clauses, join_from):
    """Given a list of FROM clauses and a selectable, return the indexes
    of the clauses which are derived from that selectable.

    A first, "liberal" pass accepts any clause derived from one of the
    FROM objects of ``join_from``.  If that accepts more than one clause,
    a second, "conservative" pass keeps only the clauses sharing a surface
    selectable with one of those FROM objects; its result is returned
    when it is non-empty, otherwise the liberal result is.

    """

    targets = list(_from_objects(join_from))

    def _derived_from_any(candidate):
        # basic check, if candidate is derived from a target.
        # this can be joins containing a table, or an aliased table
        # or select statement matching to a table.  This check
        # will match a table to a selectable that is adapted from
        # that table.  With Query, this suits the case where a join
        # is being made to an adapted entity
        return any(candidate.is_derived_from(t) for t in targets)

    liberal_idx = [
        pos for pos, candidate in enumerate(clauses)
        if _derived_from_any(candidate)
    ]

    if len(liberal_idx) <= 1:
        return liberal_idx

    # in an extremely small set of use cases, a join is being made where
    # there are multiple FROM clauses where our target table is represented
    # in more than one, such as embedded or similar.  in this case, do
    # another pass where we try to get a more exact match where we aren't
    # looking at adaption relationships.
    def _shares_surface(candidate):
        surface = set(surface_selectables(candidate))
        return any(
            surface.intersection(surface_selectables(t)) for t in targets
        )

    conservative_idx = [
        pos for pos in liberal_idx if _shares_surface(clauses[pos])
    ]
    return conservative_idx or liberal_idx
117
118
def find_left_clause_to_join_from(clauses, join_to, onclause):
    """Given a list of FROM clauses, a selectable and an optional ON
    clause, return the indexes of the clauses ``join_to`` can be joined
    from.

    With more than one clause and an ON clause, a clause qualifies when
    its columns, together with those of one of ``join_to``'s other FROM
    objects, cover every column referenced by the ON clause.  Otherwise a
    clause qualifies when ``Join._can_join`` allows it, or whenever an ON
    clause is present.  If several clauses qualify, the ones hidden by
    another clause's ``_hide_froms`` are dropped.  When an ON clause was
    given but nothing qualified, ``range(len(clauses))`` is returned so
    that every index is treated as a candidate.

    """
    candidates = set(_from_objects(join_to))

    # if we are given more than one target clause to join
    # from, use the onclause to provide a more specific answer.
    # otherwise, don't try to limit, after all, "ON TRUE" is a valid
    # on clause
    resolve_ambiguity = len(clauses) > 1 and onclause is not None
    onclause_cols = _find_columns(onclause) if resolve_ambiguity else None

    def _joinable(from_clause):
        for other in candidates.difference([from_clause]):
            if resolve_ambiguity:
                covered = set(from_clause.c).union(other.c)
                if covered.issuperset(onclause_cols):
                    return True
            elif Join._can_join(from_clause, other) or onclause is not None:
                return True
        return False

    matches = [pos for pos, f in enumerate(clauses) if _joinable(f)]

    if len(matches) > 1:
        # this is the same "hide froms" logic from
        # Selectable._get_display_froms
        hidden = set(
            chain(*[_expand_cloned(f._hide_froms) for f in clauses])
        )
        matches = [pos for pos in matches if clauses[pos] not in hidden]

    # onclause was given and none of them resolved, so assume
    # all indexes can match
    if not matches and onclause is not None:
        return range(len(clauses))
    return matches
169
170
def visit_binary_product(fn, expr):
    """Produce a traversal of the given expression, delivering
    column comparisons to the given function.

    The function is of the form::

        def my_fn(binary, left, right)

    For each binary expression located which has a
    comparison operator, the product of "left" and
    "right" will be delivered to that function,
    in terms of that binary.

    Hence an expression like::

        and_(
            (a + b) == q + func.sum(e + f),
            j == r
        )

    would have the traversal::

        a <eq> q
        a <eq> e
        a <eq> f
        b <eq> q
        b <eq> e
        b <eq> f
        j <eq> r

    That is, every combination of "left" and
    "right" that doesn't further contain
    a binary comparison is passed as pairs.

    """
    # comparison binaries currently being expanded; stack[0] is the
    # innermost one and is the "binary" argument handed to ``fn``
    stack = []

    def visit(element):
        # generator: yields the leaf elements (ColumnClause / ScalarSelect)
        # found under ``element`` and calls ``fn`` for every comparison
        # binary encountered along the way
        if isinstance(element, ScalarSelect):
            # we don't want to dig into correlated subqueries,
            # those are just column elements by themselves
            yield element
        elif element.__visit_name__ == "binary" and operators.is_comparison(
            element.operator
        ):
            stack.insert(0, element)
            # cartesian product of left leaves x right leaves; note that
            # visit(element.right) is re-run once per left leaf
            for l in visit(element.left):
                for r in visit(element.right):
                    fn(stack[0], l, r)
            stack.pop(0)
            # NOTE(review): ``visit`` is a generator function, so calling it
            # here without iterating the result executes nothing -- this
            # loop is effectively a no-op.  Consuming it would presumably
            # invoke ``fn`` again for comparisons nested under this one.
            for elem in element.get_children():
                visit(elem)
            # a comparison binary yields no leaves to an enclosing
            # expression
        else:
            if isinstance(element, ColumnClause):
                yield element
            for elem in element.get_children():
                for e in visit(elem):
                    yield e

    list(visit(expr))
    visit = None  # remove gc cycles
232
233
def find_tables(
    clause,
    check_columns=False,
    include_aliases=False,
    include_joins=False,
    include_selects=False,
    include_crud=False,
):
    """locate Table objects within the given expression.

    Table elements are always collected.  The ``include_*`` flags also
    collect SELECT / compound SELECT, join and alias elements themselves;
    ``include_crud`` collects the ``.table`` of INSERT/UPDATE/DELETE
    statements, and ``check_columns`` collects the ``.table`` of each
    column encountered.  Results are returned in traversal order and may
    contain repeats.
    """

    found = []
    collect = found.append

    def collect_crud_target(stmt):
        found.append(stmt.table)

    def collect_column_table(column):
        found.append(column.table)

    handlers = {}
    if include_selects:
        handlers["select"] = collect
        handlers["compound_select"] = collect
    if include_joins:
        handlers["join"] = collect
    if include_aliases:
        handlers["alias"] = collect
    if include_crud:
        for visit_name in ("insert", "update", "delete"):
            handlers[visit_name] = collect_crud_target
    if check_columns:
        handlers["column"] = collect_column_table
    handlers["table"] = collect

    visitors.traverse(clause, {"column_collections": False}, handlers)
    return found
272
273
def unwrap_order_by(clause):
    """Break up an 'order by' expression into individual column-expressions,
    without DESC/ASC/NULLS FIRST/NULLS LAST.

    Labels are unwrapped (unless they label a scalar subquery, together
    with one enclosing Grouping), label references are resolved to their
    target, and textual label references are skipped.  Each distinct
    column expression is returned once, in breadth-first discovery order.
    """

    seen = util.column_set()
    found = []
    pending = deque([clause])
    while pending:
        elem = pending.popleft()

        is_plain_column = isinstance(elem, ColumnElement) and not (
            isinstance(elem, UnaryExpression)
            and operators.is_ordering_modifier(elem.modifier)
        )
        if not is_plain_column:
            # an ordering modifier or a non-column element: look inside it
            pending.extend(elem.get_children())
            continue

        if isinstance(elem, Label) and not isinstance(
            elem.element, ScalarSelect
        ):
            # strip the label (and one Grouping under it), then re-examine
            inner = elem.element
            if isinstance(inner, Grouping):
                inner = inner.element
            pending.append(inner)
            continue

        if isinstance(elem, _label_reference):
            elem = elem.element
        if isinstance(elem, _textual_label_reference):
            continue
        if elem not in seen:
            seen.add(elem)
            found.append(elem)
    return found
309
310
def unwrap_label_reference(element):
    """Return a copy of ``element`` in which every label reference
    (plain or textual) is replaced by the element it refers to."""

    label_ref_types = (_label_reference, _textual_label_reference)

    def _unwrap(candidate):
        if isinstance(candidate, label_ref_types):
            return candidate.element
        return None

    return visitors.replacement_traverse(element, {}, _unwrap)
317
318
def expand_column_list_from_order_by(collist, order_by):
    """Given the columns clause and ORDER BY of a selectable,
    return a list of column expressions that can be added to the collist
    corresponding to the ORDER BY, without repeating those already
    in the collist.

    A collist entry that carries an ORDER BY label element is matched by
    its ``.element`` rather than by the entry itself.

    """
    present = {
        (col.element if col._order_by_label_element is not None else col)
        for col in collist
    }

    candidates = []
    for ob in order_by:
        candidates.extend(unwrap_order_by(ob))

    return [col for col in candidates if col not in present]
336
337
def clause_is_present(clause, search):
    """Given a target clause and a second to search within, return True
    if the target is plainly present in the search without any
    subqueries or aliases involved.

    Only the surface of ``search`` is examined (it descends through Joins
    and groupings, nothing else).

    """
    # compare with == rather than "is" so that Annotated's compare is used
    return any(clause == elem for elem in surface_selectables(search))
352
353
def tables_from_leftmost(clause):
    """Yield the leaf FROM elements of ``clause`` from left to right.

    A Join contributes its left side's leaves and then its right side's;
    a FromGrouping is looked through; any other element is yielded as-is.
    """
    if isinstance(clause, Join):
        branches = (clause.left, clause.right)
    elif isinstance(clause, FromGrouping):
        branches = (clause.element,)
    else:
        yield clause
        return

    for branch in branches:
        for leaf in tables_from_leftmost(branch):
            yield leaf
365
366
def surface_selectables(clause):
    """Yield ``clause`` and every element reachable from it through Joins
    (both sides) and FromGroupings.

    Traversal is depth-first over an explicit stack, so a Join is yielded
    before its sides and its right side is visited before its left.
    """
    pending = [clause]
    while pending:
        current = pending.pop()
        yield current
        if isinstance(current, Join):
            pending.append(current.left)
            pending.append(current.right)
        elif isinstance(current, FromGrouping):
            pending.append(current.element)
376
377
def surface_selectables_only(clause):
    """Yield the table-like and leaf elements at the surface of ``clause``.

    Joins (both sides) and FromGroupings are descended into without being
    yielded themselves.  A ColumnClause is replaced by its table when it
    has one, and yielded itself otherwise.  Any other non-None element is
    yielded.

    NOTE(review): a TableClause or Alias satisfies both the first check
    and the final catch-all branch, so it is yielded twice.  That is
    harmless for set/membership use; verify callers before relying on
    the count.
    """
    pending = [clause]
    while pending:
        elem = pending.pop()
        if isinstance(elem, (TableClause, Alias)):
            yield elem

        if isinstance(elem, Join):
            pending.extend((elem.left, elem.right))
            continue
        if isinstance(elem, FromGrouping):
            pending.append(elem.element)
            continue
        if isinstance(elem, ColumnClause):
            owner = elem.table
            if owner is None:
                yield elem
            else:
                pending.append(owner)
            continue
        if elem is not None:
            yield elem
395
396
def surface_column_elements(clause, include_scalar_selects=True):
    """Yield ``clause`` and its descendants breadth-first, skipping the
    subtrees that are not outer-exposed.

    FromGrouping children are never descended into; when
    ``include_scalar_selects`` is False, SelectBase children are skipped
    as well.  The result approximates the elements that would be
    addressable in the WHERE clause of a SELECT if this element were in
    its columns clause.
    """
    if include_scalar_selects:
        skipped = (FromGrouping,)
    else:
        skipped = (FromGrouping, SelectBase)

    queue = deque([clause])
    while queue:
        current = queue.popleft()
        yield current
        queue.extend(
            child
            for child in current.get_children()
            if not isinstance(child, skipped)
        )
414
415
def selectables_overlap(left, right):
    """Return True if ``left`` and ``right`` share at least one surface
    selectable (the elements reachable through Joins and FromGroupings).
    """
    left_surface = set(surface_selectables(left))
    shared = left_surface.intersection(surface_selectables(right))
    return len(shared) > 0
422
423
def bind_values(clause):
    """Return an ordered list of "bound" values in the given clause.

    Each bind parameter contributes its ``effective_value``, in traversal
    order.

    E.g.::

        >>> expr = and_(
        ...    table.c.foo==5, table.c.foo==7
        ... )
        >>> bind_values(expr)
        [5, 7]
    """

    collected = []
    visitors.traverse(
        clause,
        {},
        {"bindparam": lambda bind: collected.append(bind.effective_value)},
    )
    return collected
443
444
def _quote_ddl_expr(element):
    """Render ``element`` as a literal for DDL text.

    Strings are wrapped in single quotes, with any embedded single quote
    doubled; every other value is rendered with ``repr()``.
    """
    if not isinstance(element, util.string_types):
        return repr(element)
    escaped = element.replace("'", "''")
    return "'%s'" % escaped
451
452
453class _repr_base(object):
454 _LIST = 0
455 _TUPLE = 1
456 _DICT = 2
457
458 __slots__ = ("max_chars",)
459
460 def trunc(self, value):
461 rep = repr(value)
462 lenrep = len(rep)
463 if lenrep > self.max_chars:
464 segment_length = self.max_chars // 2
465 rep = (
466 rep[0:segment_length]
467 + (
468 " ... (%d characters truncated) ... "
469 % (lenrep - self.max_chars)
470 )
471 + rep[-segment_length:]
472 )
473 return rep
474
475
class _repr_row(_repr_base):
    """Provide a string view of a row.

    Values are rendered tuple-style, each passed through ``trunc``; a
    single-value row gets a trailing comma, like a 1-tuple.
    """

    __slots__ = ("row",)

    def __init__(self, row, max_chars=300):
        self.row = row
        self.max_chars = max_chars

    def __repr__(self):
        pieces = [self.trunc(value) for value in self.row]
        suffix = "," if len(self.row) == 1 else ""
        return "(" + ", ".join(pieces) + suffix + ")"
491
492
class _repr_params(_repr_base):
    """Provide a string view of bound parameters.

    Truncates display to a given number of 'multi' parameter sets,
    as well as long values to a given number of characters.

    ``ismulti`` selects the layout: ``None`` renders ``params`` as one
    truncated ``repr()``; a true value treats ``params`` as a sequence of
    parameter sets, showing at most ``batches`` of them; a false value
    treats ``params`` as a single parameter set.

    """

    __slots__ = "params", "batches", "ismulti"

    def __init__(self, params, batches, max_chars=300, ismulti=None):
        self.params = params
        self.ismulti = ismulti
        self.batches = batches
        self.max_chars = max_chars

    def __repr__(self):
        if self.ismulti is None:
            return self.trunc(self.params)

        if isinstance(self.params, list):
            typ = self._LIST
        elif isinstance(self.params, tuple):
            typ = self._TUPLE
        elif isinstance(self.params, dict):
            typ = self._DICT
        else:
            return self.trunc(self.params)

        if self.ismulti and len(self.params) > self.batches:
            # show the first (batches - 2) and the last 2 parameter sets;
            # [0:-1] / [1:] strip the closing / opening bracket so the two
            # renderings read as one collection around the summary note
            msg = " ... displaying %i of %i total bound parameter sets ... "
            return " ".join(
                (
                    self._repr_multi(self.params[: self.batches - 2], typ)[
                        0:-1
                    ],
                    msg % (self.batches, len(self.params)),
                    self._repr_multi(self.params[-2:], typ)[1:],
                )
            )
        elif self.ismulti:
            return self._repr_multi(self.params, typ)
        else:
            return self._repr_params(self.params, typ)

    def _repr_multi(self, multi_params, typ):
        """Render a sequence of parameter sets as one bracketed string.

        The layout of each set follows the type of the first set; ``typ``
        picks the outer brackets (``[...]`` for lists, ``(...)`` otherwise).
        """
        if multi_params:
            first = multi_params[0]
            if isinstance(first, list):
                elem_type = self._LIST
            elif isinstance(first, tuple):
                elem_type = self._TUPLE
            elif isinstance(first, dict):
                elem_type = self._DICT
            else:
                # raise explicitly so the check survives ``python -O``
                # (a stripped assert would leave elem_type unbound);
                # AssertionError keeps the previous exception type
                raise AssertionError(
                    "Unknown parameter type %s" % (type(first))
                )

            elements = ", ".join(
                self._repr_params(params, elem_type) for params in multi_params
            )
        else:
            elements = ""

        if typ == self._LIST:
            return "[%s]" % elements
        else:
            return "(%s)" % elements

    def _repr_params(self, params, typ):
        """Render one parameter set, truncating each value via ``trunc``."""
        trunc = self.trunc
        # compare the int type codes by value, not identity
        if typ == self._DICT:
            return "{%s}" % (
                ", ".join(
                    "%r: %s" % (key, trunc(value))
                    for key, value in params.items()
                )
            )
        elif typ == self._TUPLE:
            return "(%s%s)" % (
                ", ".join(trunc(value) for value in params),
                "," if len(params) == 1 else "",
            )
        else:
            return "[%s]" % (", ".join(trunc(value) for value in params))
579
580
def adapt_criterion_to_null(crit, nulls):
    """Return a copy of ``crit`` in which comparisons against selected
    bind parameters become NULL comparisons.

    Any binary expression with a BindParameter side whose
    ``_identifying_key`` is in ``nulls`` is rewritten as
    ``<other side> IS NULL`` (negation ``IS NOT``).  When the bind is on
    the left, the other side is moved to the left first.

    """

    def _targets_null(side):
        return isinstance(side, BindParameter) and (
            side._identifying_key in nulls
        )

    def visit_binary(binary):
        if _targets_null(binary.left):
            # reverse order if the NULL is on the left side
            binary.left = binary.right
        elif not _targets_null(binary.right):
            return
        binary.right = Null()
        binary.operator = operators.is_
        binary.negate = operators.isnot

    return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
606
607
def splice_joins(left, right, stop_on=None):
    """Adapt the left spine of the join ``right`` against ``left``.

    Starting at ``right`` and following ``.left`` links, each Join
    (other than ``stop_on``) is cloned, has its exported columns reset
    and its ON clause adapted with a ClauseAdapter for ``left``.  The
    first non-Join element reached, or ``stop_on``, is adapted as a whole
    and the walk ends there.  Each processed element is linked in as the
    ``.left`` of its (cloned) parent.  Right sides of the joins are not
    descended into.  Returns the processed top element, or ``right``
    unchanged when ``left`` is None.
    """
    if left is None:
        return right

    adapter = ClauseAdapter(left)
    pending = [(right, None)]
    result = None
    while pending:
        current, parent = pending.pop()
        if isinstance(current, Join) and current is not stop_on:
            current = current._clone()
            current._reset_exported()
            current.onclause = adapter.traverse(current.onclause)
            pending.append((current.left, current))
        else:
            current = adapter.traverse(current)
        if parent is not None:
            parent.left = current
        if result is None:
            result = current

    return result
631
632
def reduce_columns(columns, *clauses, **kw):
    r"""given a list of columns, return a 'reduced' set based on natural
    equivalents.

    the set is reduced to the smallest list of columns which have no natural
    equivalent present in the list.  A "natural equivalent" means that two
    columns will ultimately represent the same value because they are related
    by a foreign key.

    \*clauses is an optional list of join clauses which will be traversed
    to further identify columns that are "equivalent".

    \**kw may specify 'ignore_nonexistent_tables' to ignore foreign keys
    whose tables are not yet configured, or columns that aren't yet present.
    It may also specify 'only_synonyms'; when true, a column is only
    omitted in favor of an equivalent column that has the same name.

    This function is primarily used to determine the most minimal "primary
    key" from a selectable, by reducing the set of primary key columns present
    in the selectable to just those that are not repeated.

    """
    ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
    only_synonyms = kw.pop("only_synonyms", False)
    # NOTE(review): any other keyword arguments are silently ignored.

    columns = util.ordered_column_set(columns)

    omit = util.column_set()
    # pass 1: omit ``col`` when a foreign key on it (or on any column in
    # its proxy_set) targets a column that shares lineage with another
    # column ``c`` of the set
    for col in columns:
        for fk in chain(*[c.foreign_keys for c in col.proxy_set]):
            for c in columns:
                if c is col:
                    continue
                try:
                    fk_col = fk.column
                except exc.NoReferencedColumnError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                except exc.NoReferencedTableError:
                    # TODO: add specific coverage here
                    # to test/sql/test_selectable ReduceTest
                    if ignore_nonexistent_tables:
                        continue
                    else:
                        raise
                if fk_col.shares_lineage(c) and (
                    not only_synonyms or c.name == col.name
                ):
                    omit.add(col)
                    # leaves only the ``for c`` loop; the remaining
                    # foreign keys of ``col`` are still examined
                    break

    if clauses:

        # pass 2: for each ``==`` comparison whose two sides both belong
        # to the proxy sets of still-retained columns, omit the last
        # column (scanning in reverse) that shares lineage with the
        # right-hand side
        def visit_binary(binary):
            if binary.operator == operators.eq:
                cols = util.column_set(
                    chain(*[c.proxy_set for c in columns.difference(omit)])
                )
                if binary.left in cols and binary.right in cols:
                    for c in reversed(columns):
                        if c.shares_lineage(binary.right) and (
                            not only_synonyms or c.name == binary.left.name
                        ):
                            omit.add(c)
                            break

        for clause in clauses:
            if clause is not None:
                visitors.traverse(clause, {}, {"binary": visit_binary})

    return ColumnSet(columns.difference(omit))
706
707
def criterion_as_pairs(
    expression,
    consider_as_foreign_keys=None,
    consider_as_referenced_keys=None,
    any_operator=False,
):
    """traverse an expression and locate binary criterion pairs.

    Returns a list of ``(referenced_column, foreign_column)`` tuples, one
    per qualifying binary expression found in ``expression``.

    :param consider_as_foreign_keys: optional collection of columns to
      treat as the "foreign" side of a pair.
    :param consider_as_referenced_keys: optional collection of columns to
      treat as the "referenced" side of a pair.
    :param any_operator: when False (the default) only ``==`` comparisons
      are considered; when True, binaries with any operator are.
    :raises ArgumentError: if both ``consider_as_*`` collections are
      given (non-empty).

    When neither collection is given, only comparisons between two
    ``Column`` objects are paired, using ``Column.references()`` to decide
    which side is the foreign key.
    """

    if consider_as_foreign_keys and consider_as_referenced_keys:
        raise exc.ArgumentError(
            "Can only specify one of "
            "'consider_as_foreign_keys' or "
            "'consider_as_referenced_keys'"
        )

    def col_is(a, b):
        # structural comparison via compare() rather than identity
        # return a is b
        return a.compare(b)

    def visit_binary(binary):
        if not any_operator and binary.operator is not operators.eq:
            return
        if not isinstance(binary.left, ColumnElement) or not isinstance(
            binary.right, ColumnElement
        ):
            return

        # in both branches below, a side that is in the collection is
        # accepted when the other side is not in it, or when the two
        # sides compare as the same column
        if consider_as_foreign_keys:
            if binary.left in consider_as_foreign_keys and (
                col_is(binary.right, binary.left)
                or binary.right not in consider_as_foreign_keys
            ):
                pairs.append((binary.right, binary.left))
            elif binary.right in consider_as_foreign_keys and (
                col_is(binary.left, binary.right)
                or binary.left not in consider_as_foreign_keys
            ):
                pairs.append((binary.left, binary.right))
        elif consider_as_referenced_keys:
            if binary.left in consider_as_referenced_keys and (
                col_is(binary.right, binary.left)
                or binary.right not in consider_as_referenced_keys
            ):
                pairs.append((binary.left, binary.right))
            elif binary.right in consider_as_referenced_keys and (
                col_is(binary.left, binary.right)
                or binary.left not in consider_as_referenced_keys
            ):
                pairs.append((binary.right, binary.left))
        else:
            if isinstance(binary.left, Column) and isinstance(
                binary.right, Column
            ):
                if binary.left.references(binary.right):
                    pairs.append((binary.right, binary.left))
                elif binary.right.references(binary.left):
                    pairs.append((binary.left, binary.right))

    # bound before traversal; visit_binary appends to it via closure
    pairs = []
    visitors.traverse(expression, {}, {"binary": visit_binary})
    return pairs
769
770
class ClauseAdapter(visitors.ReplacingCloningVisitor):
    """Clones and modifies clauses based on column correspondence.

    E.g.::

      table1 = Table('sometable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )
      table2 = Table('someothertable', metadata,
          Column('col1', Integer),
          Column('col2', Integer)
          )

      condition = table1.c.col1 == table2.c.col1

    make an alias of table1::

      s = table1.alias('foo')

    calling ``ClauseAdapter(s).traverse(condition)`` converts
    condition to read::

      s.c.col1 == table2.c.col1

    """

    def __init__(
        self,
        selectable,
        equivalents=None,
        include_fn=None,
        exclude_fn=None,
        adapt_on_names=False,
        anonymize_labels=False,
    ):
        self.selectable = selectable
        self.include_fn = include_fn
        self.exclude_fn = exclude_fn
        self.adapt_on_names = adapt_on_names
        self.equivalents = util.column_dict(equivalents or {})
        # the traversal does not descend into the target selectable itself
        self.__traverse_options__ = dict(
            stop_on=[selectable], anonymize_labels=anonymize_labels
        )

    def _corresponding_column(
        self, col, require_embedded, _seen=util.EMPTY_SET
    ):
        """Find the column of ``self.selectable`` corresponding to ``col``.

        Falls back to ``col``'s registered equivalents (guarding against
        cycles with ``_seen``), then, when ``adapt_on_names`` is set, to a
        same-named column of the selectable.  Returns None if nothing
        matches.
        """
        found = self.selectable.corresponding_column(
            col, require_embedded=require_embedded
        )
        if found is None and col in self.equivalents and col not in _seen:
            seen = _seen.union([col])
            for equiv in self.equivalents[col]:
                found = self._corresponding_column(
                    equiv, require_embedded=require_embedded, _seen=seen
                )
                if found is not None:
                    return found
        if found is None and self.adapt_on_names:
            found = self.selectable.c.get(col.name)
        return found

    def replace(self, col):
        """Return the replacement for ``col``, or None to leave it as-is.

        A FromClause that ``self.selectable`` is derived from is replaced
        by the selectable; a ColumnElement passing include/exclude rules is
        replaced by its corresponding column.
        """
        if isinstance(col, FromClause) and self.selectable.is_derived_from(
            col
        ):
            return self.selectable
        if not isinstance(col, ColumnElement):
            return None
        if self.include_fn and not self.include_fn(col):
            return None
        if self.exclude_fn and self.exclude_fn(col):
            return None
        return self._corresponding_column(col, True)
849
850
class ColumnAdapter(ClauseAdapter):
    """Extends ClauseAdapter with extra utility functions.

    Key aspects of ColumnAdapter include:

    * Expressions that are adapted are stored in a persistent
      .columns collection; so that an expression E adapted into
      an expression E1, will return the same object E1 when adapted
      a second time.   This is important in particular for things like
      Label objects that are anonymized, so that the ColumnAdapter can
      be used to present a consistent "adapted" view of things.

    * Exclusion of items from the persistent collection based on
      include/exclude rules, but also independent of hash identity.
      This is because "annotated" items all have the same hash identity as
      their parent.

    * "wrapping" capability is added, so that the replacement of an expression
      E can proceed through a series of adapters.  This differs from the
      visitor's "chaining" feature in that the resulting object is passed
      through all replacing functions unconditionally, rather than stopping
      at the first one that returns non-None.

    * An adapt_required option, used by eager loading to indicate that
      we don't trust a result row column that is not translated.
      This is to prevent a column from being interpreted as that
      of the child row in a self-referential scenario, see
      inheritance/test_basic.py->EagerTargetingTest.test_adapt_stringency

    """

    def __init__(
        self,
        selectable,
        equivalents=None,
        adapt_required=False,
        include_fn=None,
        exclude_fn=None,
        adapt_on_names=False,
        allow_label_resolve=True,
        anonymize_labels=False,
    ):
        ClauseAdapter.__init__(
            self,
            selectable,
            equivalents,
            include_fn=include_fn,
            exclude_fn=exclude_fn,
            adapt_on_names=adapt_on_names,
            anonymize_labels=anonymize_labels,
        )

        self._build_columns_collection()
        self.adapt_required = adapt_required
        self.allow_label_resolve = allow_label_resolve
        self._wrap = None

    def _build_columns_collection(self):
        """(Re)create the memoizing ``columns`` collection.

        Lookups populate a WeakPopulateDict through ``_locate_col``; when
        include/exclude rules are configured, the dict is fronted by an
        ``_IncludeExcludeMapping`` so that rejected keys bypass adaptation.
        """
        columns = util.WeakPopulateDict(self._locate_col)
        if self.include_fn or self.exclude_fn:
            columns = self._IncludeExcludeMapping(self, columns)
        self.columns = columns

    class _IncludeExcludeMapping(object):
        """Front for ``columns`` applying the parent's include/exclude rules.

        A key rejected by ``include_fn`` / accepted by ``exclude_fn`` is
        looked up in the wrapped adapter's ``columns`` if there is one,
        and otherwise returned unchanged; other keys go to the memo dict.
        """

        def __init__(self, parent, columns):
            self.parent = parent
            self.columns = columns

        def __getitem__(self, key):
            if (
                self.parent.include_fn and not self.parent.include_fn(key)
            ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)):
                if self.parent._wrap:
                    return self.parent._wrap.columns[key]
                else:
                    return key
            return self.columns[key]

    def wrap(self, adapter):
        """Return a copy of this adapter whose results are further passed
        through ``adapter``, with a fresh ``columns`` memo."""
        ac = self.__class__.__new__(self.__class__)
        ac.__dict__.update(self.__dict__)
        ac._wrap = adapter
        ac._build_columns_collection()

        return ac

    def traverse(self, obj):
        return self.columns[obj]

    adapt_clause = traverse
    adapt_list = ClauseAdapter.copy_and_process

    def _locate_col(self, col):
        """Compute the adapted form of ``col`` for the ``columns`` memo.

        Adapts via ClauseAdapter.traverse, then through the wrapped adapter
        (if any, and if it yields a result).  Returns None when
        ``adapt_required`` is set and adaptation left ``col`` unchanged;
        otherwise sets ``_allow_label_resolve`` on the result and returns it.
        """

        c = ClauseAdapter.traverse(self, col)

        if self._wrap:
            c2 = self._wrap._locate_col(c)
            if c2 is not None:
                c = c2

        if self.adapt_required and c is col:
            return None

        c._allow_label_resolve = self.allow_label_resolve

        return c

    def __getstate__(self):
        # the memo holds weak references and is rebuilt on unpickle
        d = self.__dict__.copy()
        del d["columns"]
        return d

    def __setstate__(self, state):
        self.__dict__.update(state)
        # rebuild the memo the same way __init__ does, including the
        # include/exclude front; restoring only a bare WeakPopulateDict
        # would make an unpickled adapter ignore include_fn / exclude_fn
        self._build_columns_collection()