1# sql/compiler.py
2# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7
8"""Base SQL and DDL compiler implementations.
9
10Classes provided include:
11
12:class:`.compiler.SQLCompiler` - renders SQL
13strings
14
15:class:`.compiler.DDLCompiler` - renders DDL
16(data definition language) strings
17
18:class:`.compiler.GenericTypeCompiler` - renders
19type specification strings.
20
21To generate user-defined SQL strings, see
22:doc:`/ext/compiler`.
23
24"""
25
26import collections
27import contextlib
28import itertools
29import operator
30import re
31
32from . import base
33from . import coercions
34from . import crud
35from . import elements
36from . import functions
37from . import operators
38from . import schema
39from . import selectable
40from . import sqltypes
41from . import util as sql_util
42from .base import NO_ARG
43from .base import prefix_anon_map
44from .elements import quoted_name
45from .. import exc
46from .. import util
47
# Words the base compiler considers reserved.  NOTE(review): presumably
# consulted by the identifier preparer when deciding whether a name must be
# quoted -- the consumer is outside this view, confirm there.
RESERVED_WORDS = {
    "all", "analyse", "analyze", "and", "any", "array", "as", "asc",
    "asymmetric", "authorization",
    "between", "binary", "both",
    "case", "cast", "check", "collate", "column", "constraint", "create",
    "cross", "current_date", "current_role", "current_time",
    "current_timestamp", "current_user",
    "default", "deferrable", "desc", "distinct", "do",
    "else", "end", "except",
    "false", "for", "foreign", "freeze", "from", "full",
    "grant", "group",
    "having",
    "ilike", "in", "initially", "inner", "intersect", "into", "is",
    "isnull",
    "join",
    "leading", "left", "like", "limit", "localtime", "localtimestamp",
    "natural", "new", "not", "notnull", "null",
    "off", "offset", "old", "on", "only", "or", "order", "outer",
    "overlaps",
    "placing", "primary",
    "references", "right",
    "select", "session_user", "set", "similar", "some", "symmetric",
    "table", "then", "to", "trailing", "true",
    "union", "unique", "user", "using",
    "verbose",
    "when", "where",
}
146
# Character classes for identifiers: names built only from these characters
# (optionally plus spaces) match the patterns below.
LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.IGNORECASE)
LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.IGNORECASE)

# characters that may not begin an identifier: the ten digits and "$"
ILLEGAL_INITIAL_CHARACTERS = set("0123456789$")

# accepted foreign key ON DELETE / ON UPDATE actions (same vocabulary)
FK_ON_DELETE = re.compile(
    r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.IGNORECASE
)
FK_ON_UPDATE = re.compile(FK_ON_DELETE.pattern, re.IGNORECASE)

# accepted values for a constraint's INITIALLY clause
FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.IGNORECASE)

# ":name" bind tokens in text; "::" casts and backslash-escaped colons are
# excluded by the look-behind / look-ahead.
BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
# backslash-escaped "\:name" tokens, captured without the backslash
BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)

# placeholder template for each DBAPI paramstyle; "[_POSITION]" in the
# numeric template is later replaced by _apply_numbered_params()
BIND_TEMPLATES = {
    "pyformat": "%%(%(name)s)s",
    "qmark": "?",
    "format": "%%s",
    "numeric": ":[_POSITION]",
    "named": ":%(name)s",
}

# characters that are substituted when translating bind names, and their
# replacements
_BIND_TRANSLATE_RE = re.compile(r"[%\(\):\[\] ]")
_BIND_TRANSLATE_CHARS = {
    "%": "P",
    "(": "A",
    ")": "Z",
    ":": "C",
    "[": "_",
    "]": "_",
    " ": "_",
}
171
# SQL text rendered for each operator function.  Binary entries carry
# surrounding spaces, unary entries a trailing space and modifier entries a
# leading space, so callers can concatenate them directly with operands.
OPERATORS = {
    # binary
    operators.and_: " AND ",
    operators.or_: " OR ",
    operators.add: " + ",
    operators.mul: " * ",
    operators.sub: " - ",
    operators.div: " / ",
    operators.mod: " % ",
    operators.truediv: " / ",
    # note: neg is unary negation even though it sits in this section;
    # unlike the binary entries it has no surrounding spaces
    operators.neg: "-",
    operators.lt: " < ",
    operators.le: " <= ",
    operators.ne: " != ",
    operators.gt: " > ",
    operators.ge: " >= ",
    operators.eq: " = ",
    operators.is_distinct_from: " IS DISTINCT FROM ",
    operators.is_not_distinct_from: " IS NOT DISTINCT FROM ",
    operators.concat_op: " || ",
    operators.match_op: " MATCH ",
    operators.not_match_op: " NOT MATCH ",
    operators.in_op: " IN ",
    operators.not_in_op: " NOT IN ",
    operators.comma_op: ", ",
    operators.from_: " FROM ",
    operators.as_: " AS ",
    operators.is_: " IS ",
    operators.is_not: " IS NOT ",
    operators.collate: " COLLATE ",
    # unary
    operators.exists: "EXISTS ",
    operators.distinct_op: "DISTINCT ",
    operators.inv: "NOT ",
    operators.any_op: "ANY ",
    operators.all_op: "ALL ",
    # modifiers
    operators.desc_op: " DESC",
    operators.asc_op: " ASC",
    operators.nulls_first_op: " NULLS FIRST",
    operators.nulls_last_op: " NULLS LAST",
}

# Rendered name for each known SQL function class.  Some are upper-case
# SQL keywords (CURRENT_DATE, GROUPING SETS, ...) while others keep the
# lower-case function name.
FUNCTIONS = {
    functions.coalesce: "coalesce",
    functions.current_date: "CURRENT_DATE",
    functions.current_time: "CURRENT_TIME",
    functions.current_timestamp: "CURRENT_TIMESTAMP",
    functions.current_user: "CURRENT_USER",
    functions.localtime: "LOCALTIME",
    functions.localtimestamp: "LOCALTIMESTAMP",
    functions.random: "random",
    functions.sysdate: "sysdate",
    functions.session_user: "SESSION_USER",
    functions.user: "USER",
    functions.cube: "CUBE",
    functions.rollup: "ROLLUP",
    functions.grouping_sets: "GROUPING SETS",
}
231
# EXTRACT() field names known to the base compiler.  Each field maps to
# itself; SQLCompiler exposes this as ``extract_map`` so subclasses can
# presumably substitute dialect-specific spellings.
EXTRACT_MAP = {
    field: field
    for field in (
        "month",
        "day",
        "year",
        "second",
        "hour",
        "doy",
        "minute",
        "quarter",
        "dow",
        "week",
        "epoch",
        "milliseconds",
        "microseconds",
        "timezone_hour",
        "timezone_minute",
    )
}
249
# SQL keyword for each CompoundSelect operation (presumably rendered between
# the member SELECTs of a compound statement).
COMPOUND_KEYWORDS = dict(
    [
        (selectable.CompoundSelect.UNION, "UNION"),
        (selectable.CompoundSelect.UNION_ALL, "UNION ALL"),
        (selectable.CompoundSelect.EXCEPT, "EXCEPT"),
        (selectable.CompoundSelect.EXCEPT_ALL, "EXCEPT ALL"),
        (selectable.CompoundSelect.INTERSECT, "INTERSECT"),
        (selectable.CompoundSelect.INTERSECT_ALL, "INTERSECT ALL"),
    ]
)
258
259
# Positions within each entry of SQLCompiler._result_columns, which (per the
# comment in SQLCompiler.__init__) relate a rendered label to its local
# name, ColumnElement objects and TypeEngine.
RM_RENDERED_NAME, RM_NAME, RM_OBJECTS, RM_TYPE = range(4)


# Result of expanding post-compile parameters into a final statement.
ExpandedState = collections.namedtuple(
    "ExpandedState",
    "statement additional_parameters processors positiontup "
    "parameter_expansion",
)
276
277
# Flags accepted by SQLCompiler's ``linting`` argument.  Each symbol has an
# integer "canonical" value; FROM_LINTING is defined as the bitwise OR of
# the collect and warn flags.
NO_LINTING = util.symbol(
    "NO_LINTING",
    "Disable all linting.",
    canonical=0,
)

COLLECT_CARTESIAN_PRODUCTS = util.symbol(
    "COLLECT_CARTESIAN_PRODUCTS",
    "Collect data on FROMs and cartesian products and gather into "
    "'self.from_linter'",
    canonical=1,
)

WARN_LINTING = util.symbol(
    "WARN_LINTING",
    "Emit warnings for linters that find problems",
    canonical=2,
)

FROM_LINTING = util.symbol(
    "FROM_LINTING",
    "Warn for cartesian products; combines "
    "COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING",
    canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING,
)
297
298
class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
    """Detect FROM elements that are not connected by join conditions.

    ``froms`` is a mapping whose keys are FROM elements and whose values are
    used to describe those elements in the warning message.  ``edges`` is an
    iterable of two-element sequences, each pairing FROM elements that are
    joined to one another.
    """

    def lint(self, start=None):
        """Traverse the join graph and report FROMs never reached.

        :param start: FROM element to begin from; when ``None`` an arbitrary
          element of ``froms`` is used.
        :return: ``(unreached, start_element)`` where ``unreached`` is the
          set of FROM elements not reachable from the start, or
          ``(None, None)`` when everything is connected or ``froms`` is
          empty.
        """
        if not self.froms:
            return None, None

        remaining_edges = set(self.edges)
        unvisited = set(self.froms)

        if start is None:
            origin = unvisited.pop()
        else:
            origin = start
            unvisited.remove(origin)

        pending = collections.deque([origin])

        while pending and unvisited:
            current = pending.popleft()
            unvisited.discard(current)

            # Edge matching relies on the edge's own containment routines
            # ("in", ".index()") rather than explicit hash() calls, so that
            # annotated elements match their non-annotated counterparts.
            touching = {edge for edge in remaining_edges if current in edge}

            # queue the opposite end of every edge touching this node
            pending.extendleft(
                edge[not edge.index(current)] for edge in touching
            )
            remaining_edges.difference_update(touching)

        if not unvisited:
            return None, None
        return unvisited, origin

    def warn(self):
        """Emit a warning via util.warn() if lint() finds a cartesian product."""
        unreached, origin = self.lint()
        if not unreached:
            return

        described = ", ".join(
            '"{elem}"'.format(elem=self.froms[element])
            for element in unreached
        )
        template = (
            "SELECT statement has a cartesian product between "
            "FROM element(s) {froms} and "
            'FROM element "{start}". Apply join condition(s) '
            "between each element to resolve."
        )
        util.warn(template.format(froms=described, start=self.froms[origin]))
360
361
class Compiled(object):

    """Represent a compiled SQL or DDL expression.

    The ``__str__`` method of the ``Compiled`` object should produce
    the actual text of the statement.  ``Compiled`` objects are
    specific to their underlying database dialect, and also may
    or may not be specific to the columns referenced within a
    particular set of bind parameters.  In no case should the
    ``Compiled`` object be dependent on the actual values of those
    bind parameters, even though it may reference those values as
    defaults.
    """

    _cached_metadata = None

    _result_columns = None

    schema_translate_map = None

    execution_options = util.EMPTY_DICT
    """
    Execution options propagated from the statement.  In some cases,
    sub-elements of the statement can modify these.
    """

    _annotations = util.EMPTY_DICT

    compile_state = None
    """Optional :class:`.CompileState` object that maintains additional
    state used by the compiler.

    Major executable objects such as :class:`_expression.Insert`,
    :class:`_expression.Update`, :class:`_expression.Delete`,
    :class:`_expression.Select` will generate this
    state when compiled in order to calculate additional information about the
    object.   For the top level object that is to be executed, the state can be
    stored here where it can also have applicability towards result set
    processing.

    .. versionadded:: 1.4

    """

    dml_compile_state = None
    """Optional :class:`.CompileState` assigned at the same point that
    .isinsert, .isupdate, or .isdelete is assigned.

    This will normally be the same object as .compile_state, with the
    exception of cases like the :class:`.ORMFromStatementCompileState`
    object.

    .. versionadded:: 1.4.40

    """

    cache_key = None
    # util.perf_counter() value taken at the end of __init__
    _gen_time = None

    def __init__(
        self,
        dialect,
        statement,
        schema_translate_map=None,
        render_schema_translate=False,
        compile_kwargs=util.immutabledict(),
    ):
        """Construct a new :class:`.Compiled` object.

        :param dialect: :class:`.Dialect` to compile against.

        :param statement: :class:`_expression.ClauseElement` to be compiled.
          When ``None``, no compilation takes place and ``self.string`` is
          not assigned.

        :param schema_translate_map: dictionary of schema names to be
         translated when forming the resultant SQL

         .. versionadded:: 1.1

         .. seealso::

            :ref:`schema_translating`

        :param render_schema_translate: when true (and ``statement`` is
         given), the compiled string is passed through the preparer's
         ``_render_schema_translates()`` together with
         ``schema_translate_map`` before being stored on ``self.string``.

        :param compile_kwargs: additional kwargs that will be
         passed to the initial call to :meth:`.Compiled.process`.


        """

        self.dialect = dialect
        self.preparer = self.dialect.identifier_preparer
        if schema_translate_map:
            self.schema_translate_map = schema_translate_map
            # use a preparer variant bound to the translate map
            self.preparer = self.preparer._with_schema_translate(
                schema_translate_map
            )

        if statement is not None:
            self.statement = statement
            self.can_execute = statement.supports_execution
            self._annotations = statement._annotations
            if self.can_execute:
                self.execution_options = statement._execution_options
            self.string = self.process(self.statement, **compile_kwargs)

            if render_schema_translate:
                self.string = self.preparer._render_schema_translates(
                    self.string, schema_translate_map
                )
        self._gen_time = util.perf_counter()

    def _execute_on_connection(
        self, connection, multiparams, params, execution_options
    ):
        """Execute this compiled object on ``connection``.

        :raises ObjectNotExecutableError: if the compiled statement does not
          support execution.
        """
        if self.can_execute:
            return connection._execute_compiled(
                self, multiparams, params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self.statement)

    def visit_unsupported_compilation(self, element, err, **kw):
        """Raise :class:`.UnsupportedCompilationError` for ``element``.

        ``err`` becomes the context of the raised exception.  Additional
        keyword arguments are accepted and ignored, matching
        ``TypeCompiler.visit_unsupported_compilation``.  ``process()``
        forwards its keyword arguments into dispatch, and the dispatcher
        presumably passes them on here (confirm in ``visitors.py``).
        Accepting them means such calls produce the intended error rather
        than a ``TypeError``.
        """
        util.raise_(
            exc.UnsupportedCompilationError(self, type(element)),
            replace_context=err,
        )

    @property
    def sql_compiler(self):
        """Return a Compiled that is capable of processing SQL expressions.

        If this compiler is one, it would likely just return 'self'.

        """

        raise NotImplementedError()

    def process(self, obj, **kwargs):
        """Compile ``obj`` by dispatching to its visit method on this compiler."""
        return obj._compiler_dispatch(self, **kwargs)

    def __str__(self):
        """Return the string text of the generated SQL or DDL."""

        return self.string or ""

    def construct_params(
        self, params=None, extracted_parameters=None, escape_names=True
    ):
        """Return the bind params for this compiled object.

        :param params: a dict of string/object pairs whose values will
                       override bind values compiled in to the
                       statement.
        """

        raise NotImplementedError()

    @property
    def params(self):
        """Return the bind params for this compiled object."""
        return self.construct_params()
522
523
class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
    """Produces DDL specification for TypeEngine objects."""

    # method-name pattern read by the EnsureKWArgType metaclass
    # (presumably to guarantee visit_* methods accept **kw -- confirm in util)
    ensure_kwarg = r"visit_\w+"

    def __init__(self, dialect):
        self.dialect = dialect

    def process(self, type_, **kw):
        """Render ``type_`` by dispatching to its visit method."""
        dispatch = type_._compiler_dispatch
        return dispatch(self, **kw)

    def visit_unsupported_compilation(self, element, err, **kw):
        """Raise UnsupportedCompilationError for ``element``, chaining ``err``."""
        error = exc.UnsupportedCompilationError(self, element)
        util.raise_(error, replace_context=err)
540
541
542# this was a Visitable, but to allow accurate detection of
543# column elements this is actually a column element
class _CompileLabel(elements.ColumnElement):

    """lightweight label object which acts as an expression.Label."""

    __visit_name__ = "label"
    # every attribute assigned in __init__ is declared as a slot; previously
    # ``_alt_names`` was missing, so its assignment presumably only worked
    # because a base class provides a per-instance __dict__
    __slots__ = "element", "name", "_alt_names"

    def __init__(self, col, name, alt_names=()):
        """
        :param col: the column element being labeled.
        :param name: the label name.
        :param alt_names: tuple of extra objects stored, after ``col``, in
          ``_alt_names``.
        """
        self.element = col
        self.name = name
        self._alt_names = (col,) + alt_names

    @property
    def proxy_set(self):
        """Delegate to the labeled element's ``proxy_set``."""
        return self.element.proxy_set

    @property
    def type(self):
        """Delegate to the labeled element's ``type``."""
        return self.element.type

    def self_group(self, **kw):
        """A label never needs grouping; return it unchanged."""
        return self
566
567
568class SQLCompiler(Compiled):
    """Default implementation of :class:`.Compiled`.

    Compiles :class:`_expression.ClauseElement` objects into SQL strings.

    """

    # mapping of EXTRACT field names to rendered names (EXTRACT_MAP by
    # default; presumably overridden by dialect compilers)
    extract_map = EXTRACT_MAP

    # keyword rendered for each CompoundSelect operation
    compound_keywords = COMPOUND_KEYWORDS

    isdelete = isinsert = isupdate = False
    """class-level defaults which can be set at the instance
    level to define if this Compiled instance represents
    INSERT/UPDATE/DELETE
    """

    # NOTE(review): not assigned anywhere in this view; presumably set by
    # plain-text compilation paths -- confirm usage
    isplaintext = False

    returning = None
    """holds the "returning" collection of columns if
    the statement is CRUD and defines returning columns
    either implicitly or explicitly
    """

    returning_precedes_values = False
    """set to True classwide to generate RETURNING
    clauses before the VALUES or WHERE clause (i.e. MSSQL)
    """

    render_table_with_column_in_update_from = False
    """set to True classwide to indicate the SET clause
    in a multi-table UPDATE statement should qualify
    columns with the table name (i.e. MySQL only)
    """

    # NOTE(review): the docstring below says a subclass "can set this flag
    # to False", but False is already the default; it presumably means a
    # subclass sets it to True when the driver/DB enforces these rules.
    ansi_bind_rules = False
    """SQL 92 doesn't allow bind parameters to be used
    in the columns clause of a SELECT, nor does it allow
    ambiguous expressions like "? = ?".  A compiler
    subclass can set this flag to False if the target
    driver/DB enforces this
    """

    _textual_ordered_columns = False
    """tell the result object that the column names as rendered are important,
    but they are also "ordered" vs. what is in the compiled object here.

    As of 1.4.42 this condition is only present when the statement is a
    TextualSelect, e.g. text("....").columns(...), where it is required
    that the columns are considered positionally and not by name.

    """

    _ad_hoc_textual = False
    """tell the result that we encountered text() or '*' constructs in the
    middle of the result columns, but we also have compiled columns, so
    if the number of columns in cursor.description does not match how many
    expressions we have, that means we can't rely on positional at all and
    should match on name.

    """

    _ordered_columns = True
    """
    if False, means we can't be sure the list of entries
    in _result_columns is actually the rendered order.  Usually
    True unless using an unordered TextualSelect.
    """

    _loose_column_name_matching = False
    """tell the result object that the SQL statement is textual, wants to match
    up to Column objects, and may be using the ._tq_label in the SELECT rather
    than the base name.

    """

    _numeric_binds = False
    """
    True if paramstyle is "numeric".  This paramstyle is trickier than
    all the others.

    """

    _render_postcompile = False
    """
    whether to render out POSTCOMPILE params during the compile phase.

    """

    insert_single_values_expr = None
    """When an INSERT is compiled with a single set of parameters inside
    a VALUES expression, the string is assigned here, where it can be
    used for insert batching schemes to rewrite the VALUES expression.

    .. versionadded:: 1.3.8

    """

    literal_execute_params = frozenset()
    """bindparameter objects that are rendered as literal values at statement
    execution time.

    """

    post_compile_params = frozenset()
    """bindparameter objects that are rendered as bound parameter placeholders
    at statement execution time.

    """

    escaped_bind_names = util.EMPTY_DICT
    """Late escaping of bound parameter names that has to be converted
    to the original name when looking in the parameter dictionary.

    """

    has_out_parameters = False
    """if True, there are bindparam() objects that have the isoutparam
    flag set."""

    # INSERT / UPDATE prefetch collections; empty tuples by default.  The
    # ``prefetch`` property concatenates them.
    insert_prefetch = update_prefetch = ()

    postfetch_lastrowid = False
    """if True, and this in insert, use cursor.lastrowid to populate
    result.inserted_primary_key. """

    _cache_key_bind_match = None
    """a mapping that will relate the BindParameter object we compile
    to those that are part of the extracted collection of parameters
    in the cache key, if we were given a cache key.

    """

    # matches "__[POSTCOMPILE_<name>]" placeholders, capturing the name and
    # an optional trailing "~~...~~" segment
    _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]")

    positiontup = None
    """for a compiled construct that uses a positional paramstyle, will be
    a sequence of strings, indicating the names of bound parameters in order.

    This is used in order to render bound parameters in their correct order,
    and is combined with the :attr:`_sql.Compiled.params` dictionary to
    render parameters.

    .. seealso::

        :ref:`faq_sql_expression_string` - includes a usage example for
        debugging use cases.

    """
    # replaced with a dict in __init__ when the dialect is positional
    positiontup_level = None

    # set to True in __init__ for inline INSERT/UPDATE statements, and for
    # some executemany cases
    inline = False
721
    def __init__(
        self,
        dialect,
        statement,
        cache_key=None,
        column_keys=None,
        for_executemany=False,
        linting=NO_LINTING,
        **kwargs
    ):
        """Construct a new :class:`.SQLCompiler` object.

        :param dialect: :class:`.Dialect` to be used

        :param statement: :class:`_expression.ClauseElement` to be compiled

        :param cache_key: optional cache key for ``statement``.  When given,
          its element ``[1]`` is iterated as the collection of extracted
          bind parameters.  These are indexed into
          ``self._cache_key_bind_match``, which :meth:`.construct_params`
          uses when ``extracted_parameters`` are passed.

        :param column_keys: a list of column names to be compiled into an
          INSERT or UPDATE statement.

        :param for_executemany: whether INSERT / UPDATE statements should
          expect that they are to be invoked in an "executemany" style,
          which may impact how the statement will be expected to return the
          values of defaults and autoincrement / sequences and similar.
          Depending on the backend and driver in use, support for retrieving
          these values may be disabled which means SQL expressions may
          be rendered inline, RETURNING may not be rendered, etc.

        :param linting: linting flags stored on ``self.linting``; defaults
          to ``NO_LINTING``.

        :param kwargs: additional keyword arguments to be consumed by the
          superclass.

        """
        self.column_keys = column_keys

        self.cache_key = cache_key

        if cache_key:
            # index the extracted binds two ways: ``.key -> bind`` and
            # ``bind -> [bind]``; construct_params() iterates ckbm[b] for
            # each original extracted bind ``b``
            self._cache_key_bind_match = ckbm = {
                b.key: b for b in cache_key[1]
            }
            ckbm.update({b: [b] for b in cache_key[1]})

        # compile INSERT/UPDATE defaults/sequences to expect executemany
        # style execution, which may mean no pre-execute of defaults,
        # or no RETURNING
        self.for_executemany = for_executemany

        self.linting = linting

        # a dictionary of bind parameter keys to BindParameter
        # instances.
        self.binds = {}

        # a dictionary of BindParameter instances to "compiled" names
        # that are actually present in the generated SQL
        self.bind_names = util.column_dict()

        # stack which keeps track of nested SELECT statements
        self.stack = []

        # relates label names in the final SQL to a tuple of local
        # column/label name, ColumnElement object (if any) and
        # TypeEngine. CursorResult uses this for type processing and
        # column targeting
        self._result_columns = []

        # true if the paramstyle is positional
        self.positional = dialect.positional
        if self.positional:
            self.positiontup_level = {}
            self.positiontup = []
            self._numeric_binds = dialect.paramstyle == "numeric"
        # placeholder template for this dialect's paramstyle (all styles)
        self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]

        # CTE collections are created lazily by _init_cte_state()
        self.ctes = None

        self.label_length = (
            dialect.label_length or dialect.max_identifier_length
        )

        # a map which tracks "anonymous" identifiers that are created on
        # the fly here
        self.anon_map = prefix_anon_map()

        # a map which tracks "truncated" names based on
        # dialect.label_length or dialect.max_identifier_length
        self.truncated_names = {}

        # compiles the statement (via self.process()); the state above is
        # set up first since the visit methods invoked from there
        # presumably rely on it
        Compiled.__init__(self, dialect, statement, **kwargs)

        if self.isinsert or self.isupdate or self.isdelete:
            if statement._returning:
                self.returning = statement._returning

            if self.isinsert or self.isupdate:
                if statement._inline:
                    self.inline = True
                elif self.for_executemany and (
                    not self.isinsert
                    or (
                        self.dialect.insert_executemany_returning
                        and statement._return_defaults
                    )
                ):
                    self.inline = True

        # "numeric" placeholders are rendered as "[_POSITION]" and only
        # numbered once the full string exists
        if self.positional and self._numeric_binds:
            self._apply_numbered_params()

        if self._render_postcompile:
            self._process_parameters_for_postcompile(_populate_self=True)
832
833 @property
834 def current_executable(self):
835 """Return the current 'executable' that is being compiled.
836
837 This is currently the :class:`_sql.Select`, :class:`_sql.Insert`,
838 :class:`_sql.Update`, :class:`_sql.Delete`,
839 :class:`_sql.CompoundSelect` object that is being compiled.
840 Specifically it's assigned to the ``self.stack`` list of elements.
841
842 When a statement like the above is being compiled, it normally
843 is also assigned to the ``.statement`` attribute of the
844 :class:`_sql.Compiler` object. However, all SQL constructs are
845 ultimately nestable, and this attribute should never be consulted
846 by a ``visit_`` method, as it is not guaranteed to be assigned
847 nor guaranteed to correspond to the current statement being compiled.
848
849 .. versionadded:: 1.3.21
850
851 For compatibility with previous versions, use the following
852 recipe::
853
854 statement = getattr(self, "current_executable", False)
855 if statement is False:
856 statement = self.stack[-1]["selectable"]
857
858 For versions 1.4 and above, ensure only .current_executable
859 is used; the format of "self.stack" may change.
860
861
862 """
863 try:
864 return self.stack[-1]["selectable"]
865 except IndexError as ie:
866 util.raise_(
867 IndexError("Compiler does not have a stack entry"),
868 replace_context=ie,
869 )
870
871 @property
872 def prefetch(self):
873 return list(self.insert_prefetch + self.update_prefetch)
874
875 @util.memoized_property
876 def _global_attributes(self):
877 return {}
878
879 @util.memoized_instancemethod
880 def _init_cte_state(self):
881 """Initialize collections related to CTEs only if
882 a CTE is located, to save on the overhead of
883 these collections otherwise.
884
885 """
886 # collect CTEs to tack on top of a SELECT
887 # To store the query to print - Dict[cte, text_query]
888 self.ctes = util.OrderedDict()
889
890 # Detect same CTE references - Dict[(level, name), cte]
891 # Level is required for supporting nesting
892 self.ctes_by_level_name = {}
893
894 # To retrieve key/level in ctes_by_level_name -
895 # Dict[cte_reference, (level, cte_name)]
896 self.level_name_by_cte = {}
897
898 self.ctes_recursive = False
899 if self.positional:
900 self.cte_positional = {}
901 self.cte_level = {}
902 self.cte_order = collections.defaultdict(list)
903
904 @contextlib.contextmanager
905 def _nested_result(self):
906 """special API to support the use case of 'nested result sets'"""
907 result_columns, ordered_columns = (
908 self._result_columns,
909 self._ordered_columns,
910 )
911 self._result_columns, self._ordered_columns = [], False
912
913 try:
914 if self.stack:
915 entry = self.stack[-1]
916 entry["need_result_map_for_nested"] = True
917 else:
918 entry = None
919 yield self._result_columns, self._ordered_columns
920 finally:
921 if entry:
922 entry.pop("need_result_map_for_nested")
923 self._result_columns, self._ordered_columns = (
924 result_columns,
925 ordered_columns,
926 )
927
928 def _apply_numbered_params(self):
929 poscount = itertools.count(1)
930 self.string = re.sub(
931 r"\[_POSITION\]", lambda m: str(util.next(poscount)), self.string
932 )
933
934 @util.memoized_property
935 def _bind_processors(self):
936
937 return dict(
938 (
939 key,
940 value,
941 )
942 for key, value in (
943 (
944 self.bind_names[bindparam],
945 bindparam.type._cached_bind_processor(self.dialect)
946 if not bindparam.type._is_tuple_type
947 else tuple(
948 elem_type._cached_bind_processor(self.dialect)
949 for elem_type in bindparam.type.types
950 ),
951 )
952 for bindparam in self.bind_names
953 )
954 if value is not None
955 )
956
957 def is_subquery(self):
958 return len(self.stack) > 1
959
960 @property
961 def sql_compiler(self):
962 return self
963
964 def construct_params(
965 self,
966 params=None,
967 _group_number=None,
968 _check=True,
969 extracted_parameters=None,
970 escape_names=True,
971 ):
972 """return a dictionary of bind parameter keys and values"""
973
974 has_escaped_names = escape_names and bool(self.escaped_bind_names)
975
976 if extracted_parameters:
977 # related the bound parameters collected in the original cache key
978 # to those collected in the incoming cache key. They will not have
979 # matching names but they will line up positionally in the same
980 # way. The parameters present in self.bind_names may be clones of
981 # these original cache key params in the case of DML but the .key
982 # will be guaranteed to match.
983 try:
984 orig_extracted = self.cache_key[1]
985 except TypeError as err:
986 util.raise_(
987 exc.CompileError(
988 "This compiled object has no original cache key; "
989 "can't pass extracted_parameters to construct_params"
990 ),
991 replace_context=err,
992 )
993
994 ckbm = self._cache_key_bind_match
995 resolved_extracted = {
996 bind: extracted
997 for b, extracted in zip(orig_extracted, extracted_parameters)
998 for bind in ckbm[b]
999 }
1000 else:
1001 resolved_extracted = None
1002
1003 if params:
1004 pd = {}
1005 for bindparam, name in self.bind_names.items():
1006 escaped_name = (
1007 self.escaped_bind_names.get(name, name)
1008 if has_escaped_names
1009 else name
1010 )
1011
1012 if bindparam.key in params:
1013 pd[escaped_name] = params[bindparam.key]
1014 elif name in params:
1015 pd[escaped_name] = params[name]
1016
1017 elif _check and bindparam.required:
1018 if _group_number:
1019 raise exc.InvalidRequestError(
1020 "A value is required for bind parameter %r, "
1021 "in parameter group %d"
1022 % (bindparam.key, _group_number),
1023 code="cd3x",
1024 )
1025 else:
1026 raise exc.InvalidRequestError(
1027 "A value is required for bind parameter %r"
1028 % bindparam.key,
1029 code="cd3x",
1030 )
1031 else:
1032 if resolved_extracted:
1033 value_param = resolved_extracted.get(
1034 bindparam, bindparam
1035 )
1036 else:
1037 value_param = bindparam
1038
1039 if bindparam.callable:
1040 pd[escaped_name] = value_param.effective_value
1041 else:
1042 pd[escaped_name] = value_param.value
1043 return pd
1044 else:
1045 pd = {}
1046 for bindparam, name in self.bind_names.items():
1047 escaped_name = (
1048 self.escaped_bind_names.get(name, name)
1049 if has_escaped_names
1050 else name
1051 )
1052
1053 if _check and bindparam.required:
1054 if _group_number:
1055 raise exc.InvalidRequestError(
1056 "A value is required for bind parameter %r, "
1057 "in parameter group %d"
1058 % (bindparam.key, _group_number),
1059 code="cd3x",
1060 )
1061 else:
1062 raise exc.InvalidRequestError(
1063 "A value is required for bind parameter %r"
1064 % bindparam.key,
1065 code="cd3x",
1066 )
1067
1068 if resolved_extracted:
1069 value_param = resolved_extracted.get(bindparam, bindparam)
1070 else:
1071 value_param = bindparam
1072
1073 if bindparam.callable:
1074 pd[escaped_name] = value_param.effective_value
1075 else:
1076 pd[escaped_name] = value_param.value
1077 return pd
1078
    @util.memoized_instancemethod
    def _get_set_input_sizes_lookup(
        self, include_types=None, exclude_types=None
    ):
        """Return a dict mapping each bound parameter to its DBAPI type.

        Parameters in ``literal_execute_params`` are skipped.  For a
        tuple-typed parameter the value is a list with one entry per element
        type.  Values are ``None`` when no DBAPI type applies or when the
        type is filtered out by ``include_types`` / ``exclude_types``.

        :param include_types: optional collection.  When given, only DBAPI
          types or unwrapped dialect-impl classes found in it are kept;
          others map to ``None``.
        :param exclude_types: optional collection.  DBAPI types or unwrapped
          dialect-impl classes found in it map to ``None``.
        :return: the dict described above, or ``None`` if this compiler has
          no ``bind_names`` attribute.

        NOTE(review): ``util.memoized_instancemethod`` presumably caches the
        first call's result for the lifetime of the compiler.  Later calls
        with different include/exclude arguments would then receive that
        cached dict; confirm callers always pass the same arguments.
        """
        # NOTE(review): bind_names is assigned in __init__; presumably this
        # guards partially-constructed compilers -- confirm
        if not hasattr(self, "bind_names"):
            return None

        dialect = self.dialect
        dbapi = self.dialect.dbapi

        # _unwrapped_dialect_impl() is necessary so that we get the
        # correct dialect type for a custom TypeDecorator, or a Variant,
        # which is also a TypeDecorator.   Special types like Interval,
        # that use TypeDecorator but also might be mapped directly
        # for a dialect impl, also subclass Emulated first which overrides
        # this behavior in those cases to behave like the default.
        # (It is only used by the include/exclude branch below.)

        if include_types is None and exclude_types is None:

            # no filtering: report whatever DBAPI type the dialect impl gives
            def _lookup_type(typ):
                dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi)
                return dbtype

        else:

            def _lookup_type(typ):
                # note we get dbtype from the possibly TypeDecorator-wrapped
                # dialect_impl, but the dialect_impl itself that we use for
                # include/exclude is the unwrapped version.

                dialect_impl = typ._unwrapped_dialect_impl(dialect)

                dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi)

                # "and" binds tighter than "or", so the exclude clause reads:
                # exclude_types is None or (dbtype not excluded and
                # type(dialect_impl) not excluded)
                if (
                    dbtype is not None
                    and (
                        exclude_types is None
                        or dbtype not in exclude_types
                        and type(dialect_impl) not in exclude_types
                    )
                    and (
                        include_types is None
                        or dbtype in include_types
                        or type(dialect_impl) in include_types
                    )
                ):
                    return dbtype
                else:
                    return None

        inputsizes = {}
        literal_execute_params = self.literal_execute_params

        for bindparam in self.bind_names:
            # rendered inline at execution time; no DBAPI input size needed
            if bindparam in literal_execute_params:
                continue

            if bindparam.type._is_tuple_type:
                inputsizes[bindparam] = [
                    _lookup_type(typ) for typ in bindparam.type.types
                ]
            else:
                inputsizes[bindparam] = _lookup_type(bindparam.type)

        return inputsizes
1145
1146 @property
1147 def params(self):
1148 """Return the bind param dictionary embedded into this
1149 compiled object, for those values that are present.
1150
1151 .. seealso::
1152
1153 :ref:`faq_sql_expression_string` - includes a usage example for
1154 debugging use cases.
1155
1156 """
1157 return self.construct_params(_check=False)
1158
1159 def _process_parameters_for_postcompile(
1160 self, parameters=None, _populate_self=False
1161 ):
1162 """handle special post compile parameters.
1163
1164 These include:
1165
1166 * "expanding" parameters -typically IN tuples that are rendered
1167 on a per-parameter basis for an otherwise fixed SQL statement string.
1168
1169 * literal_binds compiled with the literal_execute flag. Used for
1170 things like SQL Server "TOP N" where the driver does not accommodate
1171 N as a bound parameter.
1172
1173 """
1174
1175 if parameters is None:
1176 parameters = self.construct_params(escape_names=False)
1177
1178 expanded_parameters = {}
1179 if self.positional:
1180 positiontup = []
1181 else:
1182 positiontup = None
1183
1184 processors = self._bind_processors
1185
1186 new_processors = {}
1187
1188 if self.positional and self._numeric_binds:
1189 # I'm not familiar with any DBAPI that uses 'numeric'.
1190 # strategy would likely be to make use of numbers greater than
1191 # the highest number present; then for expanding parameters,
1192 # append them to the end of the parameter list. that way
1193 # we avoid having to renumber all the existing parameters.
1194 raise NotImplementedError(
1195 "'post-compile' bind parameters are not supported with "
1196 "the 'numeric' paramstyle at this time."
1197 )
1198
1199 replacement_expressions = {}
1200 to_update_sets = {}
1201
1202 # notes:
1203 # *unescaped* parameter names in:
1204 # self.bind_names, self.binds, self._bind_processors
1205 #
1206 # *escaped* parameter names in:
1207 # construct_params(), replacement_expressions
1208
1209 for name in (
1210 self.positiontup if self.positional else self.bind_names.values()
1211 ):
1212 escaped_name = (
1213 self.escaped_bind_names.get(name, name)
1214 if self.escaped_bind_names
1215 else name
1216 )
1217
1218 parameter = self.binds[name]
1219 if parameter in self.literal_execute_params:
1220 if escaped_name not in replacement_expressions:
1221 replacement_expressions[
1222 escaped_name
1223 ] = self.render_literal_bindparam(
1224 parameter,
1225 render_literal_value=parameters.pop(escaped_name),
1226 )
1227 continue
1228
1229 if parameter in self.post_compile_params:
1230 if escaped_name in replacement_expressions:
1231 to_update = to_update_sets[escaped_name]
1232 else:
1233 # we are removing the parameter from parameters
1234 # because it is a list value, which is not expected by
1235 # TypeEngine objects that would otherwise be asked to
1236 # process it. the single name is being replaced with
1237 # individual numbered parameters for each value in the
1238 # param.
1239 #
1240 # note we are also inserting *escaped* parameter names
1241 # into the given dictionary. default dialect will
1242 # use these param names directly as they will not be
1243 # in the escaped_bind_names dictionary.
1244 values = parameters.pop(name)
1245
1246 leep = self._literal_execute_expanding_parameter
1247 to_update, replacement_expr = leep(
1248 escaped_name, parameter, values
1249 )
1250
1251 to_update_sets[escaped_name] = to_update
1252 replacement_expressions[escaped_name] = replacement_expr
1253
1254 if not parameter.literal_execute:
1255 parameters.update(to_update)
1256 if parameter.type._is_tuple_type:
1257 new_processors.update(
1258 (
1259 "%s_%s_%s" % (name, i, j),
1260 processors[name][j - 1],
1261 )
1262 for i, tuple_element in enumerate(values, 1)
1263 for j, value in enumerate(tuple_element, 1)
1264 if name in processors
1265 and processors[name][j - 1] is not None
1266 )
1267 else:
1268 new_processors.update(
1269 (key, processors[name])
1270 for key, value in to_update
1271 if name in processors
1272 )
1273 if self.positional:
1274 positiontup.extend(name for name, value in to_update)
1275 expanded_parameters[name] = [
1276 expand_key for expand_key, value in to_update
1277 ]
1278 elif self.positional:
1279 positiontup.append(name)
1280
1281 def process_expanding(m):
1282 key = m.group(1)
1283 expr = replacement_expressions[key]
1284
1285 # if POSTCOMPILE included a bind_expression, render that
1286 # around each element
1287 if m.group(2):
1288 tok = m.group(2).split("~~")
1289 be_left, be_right = tok[1], tok[3]
1290 expr = ", ".join(
1291 "%s%s%s" % (be_left, exp, be_right)
1292 for exp in expr.split(", ")
1293 )
1294 return expr
1295
1296 statement = re.sub(
1297 self._post_compile_pattern,
1298 process_expanding,
1299 self.string,
1300 )
1301
1302 expanded_state = ExpandedState(
1303 statement,
1304 parameters,
1305 new_processors,
1306 positiontup,
1307 expanded_parameters,
1308 )
1309
1310 if _populate_self:
1311 # this is for the "render_postcompile" flag, which is not
1312 # otherwise used internally and is for end-user debugging and
1313 # special use cases.
1314 self.string = expanded_state.statement
1315 self._bind_processors.update(expanded_state.processors)
1316 self.positiontup = expanded_state.positiontup
1317 self.post_compile_params = frozenset()
1318 for key in expanded_state.parameter_expansion:
1319 bind = self.binds.pop(key)
1320 self.bind_names.pop(bind)
1321 for value, expanded_key in zip(
1322 bind.value, expanded_state.parameter_expansion[key]
1323 ):
1324 self.binds[expanded_key] = new_param = bind._with_value(
1325 value
1326 )
1327 self.bind_names[new_param] = expanded_key
1328
1329 return expanded_state
1330
1331 @util.preload_module("sqlalchemy.engine.cursor")
1332 def _create_result_map(self):
1333 """utility method used for unit tests only."""
1334 cursor = util.preloaded.engine_cursor
1335 return cursor.CursorResultMetaData._create_description_match_map(
1336 self._result_columns
1337 )
1338
1339 @util.memoized_property
1340 def _within_exec_param_key_getter(self):
1341 getter = self._key_getters_for_crud_column[2]
1342 return getter
1343
1344 @util.memoized_property
1345 @util.preload_module("sqlalchemy.engine.result")
1346 def _inserted_primary_key_from_lastrowid_getter(self):
1347 result = util.preloaded.engine_result
1348
1349 param_key_getter = self._within_exec_param_key_getter
1350 table = self.statement.table
1351
1352 getters = [
1353 (operator.methodcaller("get", param_key_getter(col), None), col)
1354 for col in table.primary_key
1355 ]
1356
1357 autoinc_col = table._autoincrement_column
1358 if autoinc_col is not None:
1359 # apply type post processors to the lastrowid
1360 proc = autoinc_col.type._cached_result_processor(
1361 self.dialect, None
1362 )
1363 else:
1364 proc = None
1365
1366 row_fn = result.result_tuple([col.key for col in table.primary_key])
1367
1368 def get(lastrowid, parameters):
1369 """given cursor.lastrowid value and the parameters used for INSERT,
1370 return a "row" that represents the primary key, either by
1371 using the "lastrowid" or by extracting values from the parameters
1372 that were sent along with the INSERT.
1373
1374 """
1375 if proc is not None:
1376 lastrowid = proc(lastrowid)
1377
1378 if lastrowid is None:
1379 return row_fn(getter(parameters) for getter, col in getters)
1380 else:
1381 return row_fn(
1382 lastrowid if col is autoinc_col else getter(parameters)
1383 for getter, col in getters
1384 )
1385
1386 return get
1387
1388 @util.memoized_property
1389 @util.preload_module("sqlalchemy.engine.result")
1390 def _inserted_primary_key_from_returning_getter(self):
1391 result = util.preloaded.engine_result
1392
1393 param_key_getter = self._within_exec_param_key_getter
1394 table = self.statement.table
1395
1396 ret = {col: idx for idx, col in enumerate(self.returning)}
1397
1398 getters = [
1399 (operator.itemgetter(ret[col]), True)
1400 if col in ret
1401 else (
1402 operator.methodcaller("get", param_key_getter(col), None),
1403 False,
1404 )
1405 for col in table.primary_key
1406 ]
1407
1408 row_fn = result.result_tuple([col.key for col in table.primary_key])
1409
1410 def get(row, parameters):
1411 return row_fn(
1412 getter(row) if use_row else getter(parameters)
1413 for getter, use_row in getters
1414 )
1415
1416 return get
1417
1418 def default_from(self):
1419 """Called when a SELECT statement has no froms, and no FROM clause is
1420 to be appended.
1421
1422 Gives Oracle a chance to tack on a ``FROM DUAL`` to the string output.
1423
1424 """
1425 return ""
1426
1427 def visit_grouping(self, grouping, asfrom=False, **kwargs):
1428 return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
1429
1430 def visit_select_statement_grouping(self, grouping, **kwargs):
1431 return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
1432
1433 def visit_label_reference(
1434 self, element, within_columns_clause=False, **kwargs
1435 ):
1436 if self.stack and self.dialect.supports_simple_order_by_label:
1437 compile_state = self.stack[-1]["compile_state"]
1438
1439 (
1440 with_cols,
1441 only_froms,
1442 only_cols,
1443 ) = compile_state._label_resolve_dict
1444 if within_columns_clause:
1445 resolve_dict = only_froms
1446 else:
1447 resolve_dict = only_cols
1448
1449 # this can be None in the case that a _label_reference()
1450 # were subject to a replacement operation, in which case
1451 # the replacement of the Label element may have changed
1452 # to something else like a ColumnClause expression.
1453 order_by_elem = element.element._order_by_label_element
1454
1455 if (
1456 order_by_elem is not None
1457 and order_by_elem.name in resolve_dict
1458 and order_by_elem.shares_lineage(
1459 resolve_dict[order_by_elem.name]
1460 )
1461 ):
1462 kwargs[
1463 "render_label_as_label"
1464 ] = element.element._order_by_label_element
1465 return self.process(
1466 element.element,
1467 within_columns_clause=within_columns_clause,
1468 **kwargs
1469 )
1470
1471 def visit_textual_label_reference(
1472 self, element, within_columns_clause=False, **kwargs
1473 ):
1474 if not self.stack:
1475 # compiling the element outside of the context of a SELECT
1476 return self.process(element._text_clause)
1477
1478 compile_state = self.stack[-1]["compile_state"]
1479 with_cols, only_froms, only_cols = compile_state._label_resolve_dict
1480 try:
1481 if within_columns_clause:
1482 col = only_froms[element.element]
1483 else:
1484 col = with_cols[element.element]
1485 except KeyError as err:
1486 coercions._no_text_coercion(
1487 element.element,
1488 extra=(
1489 "Can't resolve label reference for ORDER BY / "
1490 "GROUP BY / DISTINCT etc."
1491 ),
1492 exc_cls=exc.CompileError,
1493 err=err,
1494 )
1495 else:
1496 kwargs["render_label_as_label"] = col
1497 return self.process(
1498 col, within_columns_clause=within_columns_clause, **kwargs
1499 )
1500
1501 def visit_label(
1502 self,
1503 label,
1504 add_to_result_map=None,
1505 within_label_clause=False,
1506 within_columns_clause=False,
1507 render_label_as_label=None,
1508 result_map_targets=(),
1509 **kw
1510 ):
1511 # only render labels within the columns clause
1512 # or ORDER BY clause of a select. dialect-specific compilers
1513 # can modify this behavior.
1514 render_label_with_as = (
1515 within_columns_clause and not within_label_clause
1516 )
1517 render_label_only = render_label_as_label is label
1518
1519 if render_label_only or render_label_with_as:
1520 if isinstance(label.name, elements._truncated_label):
1521 labelname = self._truncated_identifier("colident", label.name)
1522 else:
1523 labelname = label.name
1524
1525 if render_label_with_as:
1526 if add_to_result_map is not None:
1527 add_to_result_map(
1528 labelname,
1529 label.name,
1530 (label, labelname) + label._alt_names + result_map_targets,
1531 label.type,
1532 )
1533 return (
1534 label.element._compiler_dispatch(
1535 self,
1536 within_columns_clause=True,
1537 within_label_clause=True,
1538 **kw
1539 )
1540 + OPERATORS[operators.as_]
1541 + self.preparer.format_label(label, labelname)
1542 )
1543 elif render_label_only:
1544 return self.preparer.format_label(label, labelname)
1545 else:
1546 return label.element._compiler_dispatch(
1547 self, within_columns_clause=False, **kw
1548 )
1549
1550 def _fallback_column_name(self, column):
1551 raise exc.CompileError(
1552 "Cannot compile Column object until " "its 'name' is assigned."
1553 )
1554
1555 def visit_lambda_element(self, element, **kw):
1556 sql_element = element._resolved
1557 return self.process(sql_element, **kw)
1558
1559 def visit_column(
1560 self,
1561 column,
1562 add_to_result_map=None,
1563 include_table=True,
1564 result_map_targets=(),
1565 **kwargs
1566 ):
1567 name = orig_name = column.name
1568 if name is None:
1569 name = self._fallback_column_name(column)
1570
1571 is_literal = column.is_literal
1572 if not is_literal and isinstance(name, elements._truncated_label):
1573 name = self._truncated_identifier("colident", name)
1574
1575 if add_to_result_map is not None:
1576 targets = (column, name, column.key) + result_map_targets
1577 if column._tq_label:
1578 targets += (column._tq_label,)
1579
1580 add_to_result_map(name, orig_name, targets, column.type)
1581
1582 if is_literal:
1583 # note we are not currently accommodating for
1584 # literal_column(quoted_name('ident', True)) here
1585 name = self.escape_literal_column(name)
1586 else:
1587 name = self.preparer.quote(name)
1588 table = column.table
1589 if table is None or not include_table or not table.named_with_column:
1590 return name
1591 else:
1592 effective_schema = self.preparer.schema_for_object(table)
1593
1594 if effective_schema:
1595 schema_prefix = (
1596 self.preparer.quote_schema(effective_schema) + "."
1597 )
1598 else:
1599 schema_prefix = ""
1600 tablename = table.name
1601 if isinstance(tablename, elements._truncated_label):
1602 tablename = self._truncated_identifier("alias", tablename)
1603
1604 return schema_prefix + self.preparer.quote(tablename) + "." + name
1605
1606 def visit_collation(self, element, **kw):
1607 return self.preparer.format_collation(element.collation)
1608
1609 def visit_fromclause(self, fromclause, **kwargs):
1610 return fromclause.name
1611
1612 def visit_index(self, index, **kwargs):
1613 return index.name
1614
1615 def visit_typeclause(self, typeclause, **kw):
1616 kw["type_expression"] = typeclause
1617 kw["identifier_preparer"] = self.preparer
1618 return self.dialect.type_compiler.process(typeclause.type, **kw)
1619
1620 def post_process_text(self, text):
1621 if self.preparer._double_percents:
1622 text = text.replace("%", "%%")
1623 return text
1624
1625 def escape_literal_column(self, text):
1626 if self.preparer._double_percents:
1627 text = text.replace("%", "%%")
1628 return text
1629
1630 def visit_textclause(self, textclause, add_to_result_map=None, **kw):
1631 def do_bindparam(m):
1632 name = m.group(1)
1633 if name in textclause._bindparams:
1634 return self.process(textclause._bindparams[name], **kw)
1635 else:
1636 return self.bindparam_string(name, **kw)
1637
1638 if not self.stack:
1639 self.isplaintext = True
1640
1641 if add_to_result_map:
1642 # text() object is present in the columns clause of a
1643 # select(). Add a no-name entry to the result map so that
1644 # row[text()] produces a result
1645 add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE)
1646
1647 # un-escape any \:params
1648 return BIND_PARAMS_ESC.sub(
1649 lambda m: m.group(1),
1650 BIND_PARAMS.sub(
1651 do_bindparam, self.post_process_text(textclause.text)
1652 ),
1653 )
1654
    def visit_textual_select(
        self, taf, compound_index=None, asfrom=False, **kw
    ):
        """Render a textual SELECT ``taf``.

        ``taf`` pairs literal SQL text (``taf.element``) with column
        expressions (``taf.column_args``) describing its result columns.

        A new stack entry is pushed for the duration of compilation.  Any
        independent CTEs are compiled, the column expressions are
        registered in the result map when this statement is responsible for
        it, ``taf.element`` is rendered, and collected CTEs are prepended.

        :param taf: the textual select being compiled.
        :param compound_index: position of this statement within an
         enclosing compound select, if any.
        :param asfrom: accepted for signature compatibility; not used here.
        :return: the rendered SQL string.
        """

        # an empty stack means this is the outermost statement
        toplevel = not self.stack
        entry = self._default_stack_entry if toplevel else self.stack[-1]

        new_entry = {
            "correlate_froms": set(),
            "asfrom_froms": set(),
            "selectable": taf,
        }
        self.stack.append(new_entry)

        if taf._independent_ctes:
            for cte in taf._independent_ctes:
                cte._compiler_dispatch(self, **kw)

        # record result columns for the top-level statement, for the first
        # member of a compound that asks for them, or for a nested
        # statement whose enclosing entry requests them
        populate_result_map = (
            toplevel
            or (
                compound_index == 0
                and entry.get("need_result_map_for_compound", False)
            )
            or entry.get("need_result_map_for_nested", False)
        )

        if populate_result_map:
            # both attributes mirror taf.positional
            self._ordered_columns = (
                self._textual_ordered_columns
            ) = taf.positional

            # enable looser result column matching when the SQL text links to
            # Column objects by name only
            self._loose_column_name_matching = not taf.positional and bool(
                taf.column_args
            )

            for c in taf.column_args:
                self.process(
                    c,
                    within_columns_clause=True,
                    add_to_result_map=self._add_to_result_map,
                )

        text = self.process(taf.element, **kw)
        if self.ctes:
            # nested statements pass their stack depth; top level passes None
            nesting_level = len(self.stack) if not toplevel else None
            text = (
                self._render_cte_clause(
                    nesting_level=nesting_level,
                    visiting_cte=kw.get("visiting_cte"),
                )
                + text
            )

        # NOTE(review): not in a ``finally``; an exception above leaves
        # new_entry on the stack
        self.stack.pop(-1)

        return text
1714
1715 def visit_null(self, expr, **kw):
1716 return "NULL"
1717
1718 def visit_true(self, expr, **kw):
1719 if self.dialect.supports_native_boolean:
1720 return "true"
1721 else:
1722 return "1"
1723
1724 def visit_false(self, expr, **kw):
1725 if self.dialect.supports_native_boolean:
1726 return "false"
1727 else:
1728 return "0"
1729
1730 def _generate_delimited_list(self, elements, separator, **kw):
1731 return separator.join(
1732 s
1733 for s in (c._compiler_dispatch(self, **kw) for c in elements)
1734 if s
1735 )
1736
1737 def _generate_delimited_and_list(self, clauses, **kw):
1738
1739 lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean(
1740 operators.and_,
1741 elements.True_._singleton,
1742 elements.False_._singleton,
1743 clauses,
1744 )
1745 if lcc == 1:
1746 return clauses[0]._compiler_dispatch(self, **kw)
1747 else:
1748 separator = OPERATORS[operators.and_]
1749 return separator.join(
1750 s
1751 for s in (c._compiler_dispatch(self, **kw) for c in clauses)
1752 if s
1753 )
1754
1755 def visit_tuple(self, clauselist, **kw):
1756 return "(%s)" % self.visit_clauselist(clauselist, **kw)
1757
1758 def visit_clauselist(self, clauselist, **kw):
1759 sep = clauselist.operator
1760 if sep is None:
1761 sep = " "
1762 else:
1763 sep = OPERATORS[clauselist.operator]
1764
1765 return self._generate_delimited_list(clauselist.clauses, sep, **kw)
1766
1767 def visit_case(self, clause, **kwargs):
1768 x = "CASE "
1769 if clause.value is not None:
1770 x += clause.value._compiler_dispatch(self, **kwargs) + " "
1771 for cond, result in clause.whens:
1772 x += (
1773 "WHEN "
1774 + cond._compiler_dispatch(self, **kwargs)
1775 + " THEN "
1776 + result._compiler_dispatch(self, **kwargs)
1777 + " "
1778 )
1779 if clause.else_ is not None:
1780 x += (
1781 "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " "
1782 )
1783 x += "END"
1784 return x
1785
1786 def visit_type_coerce(self, type_coerce, **kw):
1787 return type_coerce.typed_expression._compiler_dispatch(self, **kw)
1788
1789 def visit_cast(self, cast, **kwargs):
1790 return "CAST(%s AS %s)" % (
1791 cast.clause._compiler_dispatch(self, **kwargs),
1792 cast.typeclause._compiler_dispatch(self, **kwargs),
1793 )
1794
1795 def _format_frame_clause(self, range_, **kw):
1796
1797 return "%s AND %s" % (
1798 "UNBOUNDED PRECEDING"
1799 if range_[0] is elements.RANGE_UNBOUNDED
1800 else "CURRENT ROW"
1801 if range_[0] is elements.RANGE_CURRENT
1802 else "%s PRECEDING"
1803 % (self.process(elements.literal(abs(range_[0])), **kw),)
1804 if range_[0] < 0
1805 else "%s FOLLOWING"
1806 % (self.process(elements.literal(range_[0]), **kw),),
1807 "UNBOUNDED FOLLOWING"
1808 if range_[1] is elements.RANGE_UNBOUNDED
1809 else "CURRENT ROW"
1810 if range_[1] is elements.RANGE_CURRENT
1811 else "%s PRECEDING"
1812 % (self.process(elements.literal(abs(range_[1])), **kw),)
1813 if range_[1] < 0
1814 else "%s FOLLOWING"
1815 % (self.process(elements.literal(range_[1]), **kw),),
1816 )
1817
1818 def visit_over(self, over, **kwargs):
1819 text = over.element._compiler_dispatch(self, **kwargs)
1820 if over.range_:
1821 range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
1822 over.range_, **kwargs
1823 )
1824 elif over.rows:
1825 range_ = "ROWS BETWEEN %s" % self._format_frame_clause(
1826 over.rows, **kwargs
1827 )
1828 else:
1829 range_ = None
1830
1831 return "%s OVER (%s)" % (
1832 text,
1833 " ".join(
1834 [
1835 "%s BY %s"
1836 % (word, clause._compiler_dispatch(self, **kwargs))
1837 for word, clause in (
1838 ("PARTITION", over.partition_by),
1839 ("ORDER", over.order_by),
1840 )
1841 if clause is not None and len(clause)
1842 ]
1843 + ([range_] if range_ else [])
1844 ),
1845 )
1846
1847 def visit_withingroup(self, withingroup, **kwargs):
1848 return "%s WITHIN GROUP (ORDER BY %s)" % (
1849 withingroup.element._compiler_dispatch(self, **kwargs),
1850 withingroup.order_by._compiler_dispatch(self, **kwargs),
1851 )
1852
1853 def visit_funcfilter(self, funcfilter, **kwargs):
1854 return "%s FILTER (WHERE %s)" % (
1855 funcfilter.func._compiler_dispatch(self, **kwargs),
1856 funcfilter.criterion._compiler_dispatch(self, **kwargs),
1857 )
1858
1859 def visit_extract(self, extract, **kwargs):
1860 field = self.extract_map.get(extract.field, extract.field)
1861 return "EXTRACT(%s FROM %s)" % (
1862 field,
1863 extract.expr._compiler_dispatch(self, **kwargs),
1864 )
1865
1866 def visit_scalar_function_column(self, element, **kw):
1867 compiled_fn = self.visit_function(element.fn, **kw)
1868 compiled_col = self.visit_column(element, **kw)
1869 return "(%s).%s" % (compiled_fn, compiled_col)
1870
1871 def visit_function(self, func, add_to_result_map=None, **kwargs):
1872 if add_to_result_map is not None:
1873 add_to_result_map(func.name, func.name, (), func.type)
1874
1875 disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
1876 if disp:
1877 text = disp(func, **kwargs)
1878 else:
1879 name = FUNCTIONS.get(func._deannotate().__class__, None)
1880 if name:
1881 if func._has_args:
1882 name += "%(expr)s"
1883 else:
1884 name = func.name
1885 name = (
1886 self.preparer.quote(name)
1887 if self.preparer._requires_quotes_illegal_chars(name)
1888 or isinstance(name, elements.quoted_name)
1889 else name
1890 )
1891 name = name + "%(expr)s"
1892 text = ".".join(
1893 [
1894 (
1895 self.preparer.quote(tok)
1896 if self.preparer._requires_quotes_illegal_chars(tok)
1897 or isinstance(name, elements.quoted_name)
1898 else tok
1899 )
1900 for tok in func.packagenames
1901 ]
1902 + [name]
1903 ) % {"expr": self.function_argspec(func, **kwargs)}
1904
1905 if func._with_ordinality:
1906 text += " WITH ORDINALITY"
1907 return text
1908
1909 def visit_next_value_func(self, next_value, **kw):
1910 return self.visit_sequence(next_value.sequence)
1911
1912 def visit_sequence(self, sequence, **kw):
1913 raise NotImplementedError(
1914 "Dialect '%s' does not support sequence increments."
1915 % self.dialect.name
1916 )
1917
1918 def function_argspec(self, func, **kwargs):
1919 return func.clause_expr._compiler_dispatch(self, **kwargs)
1920
    def visit_compound_select(
        self, cs, asfrom=False, compound_index=None, **kwargs
    ):
        """Render a compound select: ``cs.selects`` joined by the keyword
        that ``self.compound_keywords`` maps ``cs.keyword`` to.

        Each member SELECT is compiled with its ``compound_index``.  The
        compound's own GROUP BY, ORDER BY and row-limiting clauses follow,
        and any collected CTEs are prepended.

        :param cs: the compound select being compiled.
        :param asfrom: forwarded to each member SELECT and to
         ``group_by_clause``.
        :param compound_index: position of ``cs`` within an enclosing
         compound, if any.
        :return: the rendered SQL string.
        """
        toplevel = not self.stack

        compile_state = cs._compile_state_factory(cs, self, **kwargs)

        # the first top-level statement supplies this compiler's
        # compile_state
        if toplevel and not self.compile_state:
            self.compile_state = compile_state

        compound_stmt = compile_state.statement

        entry = self._default_stack_entry if toplevel else self.stack[-1]
        # ``not compound_index`` is true both for None and for index 0
        need_result_map = toplevel or (
            not compound_index
            and entry.get("need_result_map_for_compound", False)
        )

        # indicates there is already a CompoundSelect in play
        if compound_index == 0:
            entry["select_0"] = cs

        self.stack.append(
            {
                "correlate_froms": entry["correlate_froms"],
                "asfrom_froms": entry["asfrom_froms"],
                "selectable": cs,
                "compile_state": compile_state,
                "need_result_map_for_compound": need_result_map,
            }
        )

        if compound_stmt._independent_ctes:
            for cte in compound_stmt._independent_ctes:
                cte._compiler_dispatch(self, **kwargs)

        keyword = self.compound_keywords.get(cs.keyword)

        text = (" " + keyword + " ").join(
            (
                c._compiler_dispatch(
                    self, asfrom=asfrom, compound_index=i, **kwargs
                )
                for i, c in enumerate(cs.selects)
            )
        )

        # the trailing clauses are rendered with include_table=False, which
        # makes visit_column emit bare column names
        kwargs["include_table"] = False
        text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs))
        text += self.order_by_clause(cs, **kwargs)
        if cs._has_row_limiting_clause:
            text += self._row_limit_clause(cs, **kwargs)

        if self.ctes:
            nesting_level = len(self.stack) if not toplevel else None
            text = (
                self._render_cte_clause(
                    nesting_level=nesting_level,
                    include_following_stack=True,
                    visiting_cte=kwargs.get("visiting_cte"),
                )
                + text
            )

        self.stack.pop(-1)
        return text
1987
1988 def _row_limit_clause(self, cs, **kwargs):
1989 if cs._fetch_clause is not None:
1990 return self.fetch_clause(cs, **kwargs)
1991 else:
1992 return self.limit_clause(cs, **kwargs)
1993
1994 def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
1995 attrname = "visit_%s_%s%s" % (
1996 operator_.__name__,
1997 qualifier1,
1998 "_" + qualifier2 if qualifier2 else "",
1999 )
2000 return getattr(self, attrname, None)
2001
2002 def visit_unary(
2003 self, unary, add_to_result_map=None, result_map_targets=(), **kw
2004 ):
2005
2006 if add_to_result_map is not None:
2007 result_map_targets += (unary,)
2008 kw["add_to_result_map"] = add_to_result_map
2009 kw["result_map_targets"] = result_map_targets
2010
2011 if unary.operator:
2012 if unary.modifier:
2013 raise exc.CompileError(
2014 "Unary expression does not support operator "
2015 "and modifier simultaneously"
2016 )
2017 disp = self._get_operator_dispatch(
2018 unary.operator, "unary", "operator"
2019 )
2020 if disp:
2021 return disp(unary, unary.operator, **kw)
2022 else:
2023 return self._generate_generic_unary_operator(
2024 unary, OPERATORS[unary.operator], **kw
2025 )
2026 elif unary.modifier:
2027 disp = self._get_operator_dispatch(
2028 unary.modifier, "unary", "modifier"
2029 )
2030 if disp:
2031 return disp(unary, unary.modifier, **kw)
2032 else:
2033 return self._generate_generic_unary_modifier(
2034 unary, OPERATORS[unary.modifier], **kw
2035 )
2036 else:
2037 raise exc.CompileError(
2038 "Unary expression has no operator or modifier"
2039 )
2040
2041 def visit_is_true_unary_operator(self, element, operator, **kw):
2042 if (
2043 element._is_implicitly_boolean
2044 or self.dialect.supports_native_boolean
2045 ):
2046 return self.process(element.element, **kw)
2047 else:
2048 return "%s = 1" % self.process(element.element, **kw)
2049
2050 def visit_is_false_unary_operator(self, element, operator, **kw):
2051 if (
2052 element._is_implicitly_boolean
2053 or self.dialect.supports_native_boolean
2054 ):
2055 return "NOT %s" % self.process(element.element, **kw)
2056 else:
2057 return "%s = 0" % self.process(element.element, **kw)
2058
2059 def visit_not_match_op_binary(self, binary, operator, **kw):
2060 return "NOT %s" % self.visit_binary(
2061 binary, override_operator=operators.match_op
2062 )
2063
2064 def visit_not_in_op_binary(self, binary, operator, **kw):
2065 # The brackets are required in the NOT IN operation because the empty
2066 # case is handled using the form "(col NOT IN (null) OR 1 = 1)".
2067 # The presence of the OR makes the brackets required.
2068 return "(%s)" % self._generate_generic_binary(
2069 binary, OPERATORS[operator], **kw
2070 )
2071
2072 def visit_empty_set_op_expr(self, type_, expand_op):
2073 if expand_op is operators.not_in_op:
2074 if len(type_) > 1:
2075 return "(%s)) OR (1 = 1" % (
2076 ", ".join("NULL" for element in type_)
2077 )
2078 else:
2079 return "NULL) OR (1 = 1"
2080 elif expand_op is operators.in_op:
2081 if len(type_) > 1:
2082 return "(%s)) AND (1 != 1" % (
2083 ", ".join("NULL" for element in type_)
2084 )
2085 else:
2086 return "NULL) AND (1 != 1"
2087 else:
2088 return self.visit_empty_set_expr(type_)
2089
2090 def visit_empty_set_expr(self, element_types):
2091 raise NotImplementedError(
2092 "Dialect '%s' does not support empty set expression."
2093 % self.dialect.name
2094 )
2095
    def _literal_execute_expanding_parameter_literal_binds(
        self, parameter, values, bind_expression_template=None
    ):
        """Render an expanding parameter's values inline as SQL literals.

        :param parameter: the expanding bind parameter being rendered.
        :param values: the sequence of values to render.
        :param bind_expression_template: optional text containing a
         post-compile token whose second group carries a ``~~``-delimited
         bind expression.  When given, each rendered literal is wrapped in
         the left/right parts taken from that token.
        :return: ``((), replacement_expression)``; no new bound parameters
         are produced because every value is rendered as a literal.
        :raises NotImplementedError: when a tuple-style parameter's type has
         a bind expression.
        """

        typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)

        if not values:
            # empty IN expression.  note we don't need to use
            # bind_expression_template here because there are no
            # expressions to render.

            if typ_dialect_impl._is_tuple_type:
                replacement_expression = (
                    "VALUES " if self.dialect.tuple_in_values else ""
                ) + self.visit_empty_set_op_expr(
                    parameter.type.types, parameter.expand_op
                )

            else:
                replacement_expression = self.visit_empty_set_op_expr(
                    [parameter.type], parameter.expand_op
                )

        # tuple values: either a tuple type, or an untyped (NullType)
        # parameter whose first value is a non-string sequence
        elif typ_dialect_impl._is_tuple_type or (
            typ_dialect_impl._isnull
            and isinstance(values[0], util.collections_abc.Sequence)
            and not isinstance(
                values[0], util.string_types + util.binary_types
            )
        ):

            if typ_dialect_impl._has_bind_expression:
                raise NotImplementedError(
                    "bind_expression() on TupleType not supported with "
                    "literal_binds"
                )

            # NOTE(review): in the _isnull branch ``parameter.type`` is not
            # a tuple type; confirm it actually provides ``.types`` here
            replacement_expression = (
                "VALUES " if self.dialect.tuple_in_values else ""
            ) + ", ".join(
                "(%s)"
                % (
                    ", ".join(
                        self.render_literal_value(value, param_type)
                        for value, param_type in zip(
                            tuple_element, parameter.type.types
                        )
                    )
                )
                for i, tuple_element in enumerate(values)
            )
        else:
            if bind_expression_template:
                post_compile_pattern = self._post_compile_pattern
                m = post_compile_pattern.search(bind_expression_template)
                assert m and m.group(
                    2
                ), "unexpected format for expanding parameter"

                # group 2 is "~~<left>~~<expr>~~<right>~~"-style; tokens 1
                # and 3 are the text rendered around each value
                tok = m.group(2).split("~~")
                be_left, be_right = tok[1], tok[3]
                replacement_expression = ", ".join(
                    "%s%s%s"
                    % (
                        be_left,
                        self.render_literal_value(value, parameter.type),
                        be_right,
                    )
                    for value in values
                )
            else:
                replacement_expression = ", ".join(
                    self.render_literal_value(value, parameter.type)
                    for value in values
                )

        return (), replacement_expression
2173
    def _literal_execute_expanding_parameter(self, name, parameter, values):
        """Expand an "expanding" bind parameter into per-element binds.

        Parameters flagged ``literal_execute`` are delegated to
        ``_literal_execute_expanding_parameter_literal_binds``.  Otherwise
        one new bind name is generated per value: ``<name>_<i>``, or
        ``<name>_<i>_<j>`` for tuple elements, both 1-based.  The
        replacement SQL renders each one through ``self.bindtemplate``.

        :param name: prefix for the generated bind names (the caller in
         ``_process_parameters_for_postcompile`` passes the escaped name).
        :param parameter: the bind parameter being expanded.
        :param values: the sequence of values bound to the parameter.
        :return: ``(to_update, replacement_expression)``.  ``to_update`` is a
         list of ``(new_name, value)`` pairs (empty for an empty ``values``);
         ``replacement_expression`` is the SQL text that replaces the
         parameter's placeholder.
        """

        if parameter.literal_execute:
            return self._literal_execute_expanding_parameter_literal_binds(
                parameter, values
            )

        typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)

        if not values:
            # empty IN: no new parameters, only the empty-set expression
            to_update = []
            if typ_dialect_impl._is_tuple_type:

                replacement_expression = self.visit_empty_set_op_expr(
                    parameter.type.types, parameter.expand_op
                )
            else:
                replacement_expression = self.visit_empty_set_op_expr(
                    [parameter.type], parameter.expand_op
                )

        # tuple values: either a tuple type, or an untyped (NullType)
        # parameter whose first value is a non-string sequence
        elif typ_dialect_impl._is_tuple_type or (
            typ_dialect_impl._isnull
            and isinstance(values[0], util.collections_abc.Sequence)
            and not isinstance(
                values[0], util.string_types + util.binary_types
            )
        ):
            assert not typ_dialect_impl._is_array
            to_update = [
                ("%s_%s_%s" % (name, i, j), value)
                for i, tuple_element in enumerate(values, 1)
                for j, value in enumerate(tuple_element, 1)
            ]
            # NOTE(review): the flat index i * len(tuple_element) + j
            # assumes every tuple in ``values`` has the same length
            replacement_expression = (
                "VALUES " if self.dialect.tuple_in_values else ""
            ) + ", ".join(
                "(%s)"
                % (
                    ", ".join(
                        self.bindtemplate
                        % {"name": to_update[i * len(tuple_element) + j][0]}
                        for j, value in enumerate(tuple_element)
                    )
                )
                for i, tuple_element in enumerate(values)
            )
        else:
            to_update = [
                ("%s_%s" % (name, i), value)
                for i, value in enumerate(values, 1)
            ]
            replacement_expression = ", ".join(
                self.bindtemplate % {"name": key} for key, value in to_update
            )

        return to_update, replacement_expression
2231
2232 def visit_binary(
2233 self,
2234 binary,
2235 override_operator=None,
2236 eager_grouping=False,
2237 from_linter=None,
2238 lateral_from_linter=None,
2239 **kw
2240 ):
2241 if from_linter and operators.is_comparison(binary.operator):
2242 if lateral_from_linter is not None:
2243 enclosing_lateral = kw["enclosing_lateral"]
2244 lateral_from_linter.edges.update(
2245 itertools.product(
2246 binary.left._from_objects + [enclosing_lateral],
2247 binary.right._from_objects + [enclosing_lateral],
2248 )
2249 )
2250 else:
2251 from_linter.edges.update(
2252 itertools.product(
2253 binary.left._from_objects, binary.right._from_objects
2254 )
2255 )
2256
2257 # don't allow "? = ?" to render
2258 if (
2259 self.ansi_bind_rules
2260 and isinstance(binary.left, elements.BindParameter)
2261 and isinstance(binary.right, elements.BindParameter)
2262 ):
2263 kw["literal_execute"] = True
2264
2265 operator_ = override_operator or binary.operator
2266 disp = self._get_operator_dispatch(operator_, "binary", None)
2267 if disp:
2268 return disp(binary, operator_, **kw)
2269 else:
2270 try:
2271 opstring = OPERATORS[operator_]
2272 except KeyError as err:
2273 util.raise_(
2274 exc.UnsupportedCompilationError(self, operator_),
2275 replace_context=err,
2276 )
2277 else:
2278 return self._generate_generic_binary(
2279 binary,
2280 opstring,
2281 from_linter=from_linter,
2282 lateral_from_linter=lateral_from_linter,
2283 **kw
2284 )
2285
2286 def visit_function_as_comparison_op_binary(self, element, operator, **kw):
2287 return self.process(element.sql_function, **kw)
2288
2289 def visit_mod_binary(self, binary, operator, **kw):
2290 if self.preparer._double_percents:
2291 return (
2292 self.process(binary.left, **kw)
2293 + " %% "
2294 + self.process(binary.right, **kw)
2295 )
2296 else:
2297 return (
2298 self.process(binary.left, **kw)
2299 + " % "
2300 + self.process(binary.right, **kw)
2301 )
2302
2303 def visit_custom_op_binary(self, element, operator, **kw):
2304 kw["eager_grouping"] = operator.eager_grouping
2305 return self._generate_generic_binary(
2306 element,
2307 " " + self.escape_literal_column(operator.opstring) + " ",
2308 **kw
2309 )
2310
2311 def visit_custom_op_unary_operator(self, element, operator, **kw):
2312 return self._generate_generic_unary_operator(
2313 element, self.escape_literal_column(operator.opstring) + " ", **kw
2314 )
2315
2316 def visit_custom_op_unary_modifier(self, element, operator, **kw):
2317 return self._generate_generic_unary_modifier(
2318 element, " " + self.escape_literal_column(operator.opstring), **kw
2319 )
2320
2321 def _generate_generic_binary(
2322 self, binary, opstring, eager_grouping=False, **kw
2323 ):
2324
2325 _in_binary = kw.get("_in_binary", False)
2326
2327 kw["_in_binary"] = True
2328 kw["_binary_op"] = binary.operator
2329 text = (
2330 binary.left._compiler_dispatch(
2331 self, eager_grouping=eager_grouping, **kw
2332 )
2333 + opstring
2334 + binary.right._compiler_dispatch(
2335 self, eager_grouping=eager_grouping, **kw
2336 )
2337 )
2338
2339 if _in_binary and eager_grouping:
2340 text = "(%s)" % text
2341 return text
2342
2343 def _generate_generic_unary_operator(self, unary, opstring, **kw):
2344 return opstring + unary.element._compiler_dispatch(self, **kw)
2345
2346 def _generate_generic_unary_modifier(self, unary, opstring, **kw):
2347 return unary.element._compiler_dispatch(self, **kw) + opstring
2348
2349 @util.memoized_property
2350 def _like_percent_literal(self):
2351 return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE)
2352
2353 def visit_contains_op_binary(self, binary, operator, **kw):
2354 binary = binary._clone()
2355 percent = self._like_percent_literal
2356 binary.right = percent.concat(binary.right).concat(percent)
2357 return self.visit_like_op_binary(binary, operator, **kw)
2358
2359 def visit_not_contains_op_binary(self, binary, operator, **kw):
2360 binary = binary._clone()
2361 percent = self._like_percent_literal
2362 binary.right = percent.concat(binary.right).concat(percent)
2363 return self.visit_not_like_op_binary(binary, operator, **kw)
2364
2365 def visit_startswith_op_binary(self, binary, operator, **kw):
2366 binary = binary._clone()
2367 percent = self._like_percent_literal
2368 binary.right = percent._rconcat(binary.right)
2369 return self.visit_like_op_binary(binary, operator, **kw)
2370
2371 def visit_not_startswith_op_binary(self, binary, operator, **kw):
2372 binary = binary._clone()
2373 percent = self._like_percent_literal
2374 binary.right = percent._rconcat(binary.right)
2375 return self.visit_not_like_op_binary(binary, operator, **kw)
2376
2377 def visit_endswith_op_binary(self, binary, operator, **kw):
2378 binary = binary._clone()
2379 percent = self._like_percent_literal
2380 binary.right = percent.concat(binary.right)
2381 return self.visit_like_op_binary(binary, operator, **kw)
2382
2383 def visit_not_endswith_op_binary(self, binary, operator, **kw):
2384 binary = binary._clone()
2385 percent = self._like_percent_literal
2386 binary.right = percent.concat(binary.right)
2387 return self.visit_not_like_op_binary(binary, operator, **kw)
2388
2389 def visit_like_op_binary(self, binary, operator, **kw):
2390 escape = binary.modifiers.get("escape", None)
2391
2392 # TODO: use ternary here, not "and"/ "or"
2393 return "%s LIKE %s" % (
2394 binary.left._compiler_dispatch(self, **kw),
2395 binary.right._compiler_dispatch(self, **kw),
2396 ) + (
2397 " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
2398 if escape
2399 else ""
2400 )
2401
2402 def visit_not_like_op_binary(self, binary, operator, **kw):
2403 escape = binary.modifiers.get("escape", None)
2404 return "%s NOT LIKE %s" % (
2405 binary.left._compiler_dispatch(self, **kw),
2406 binary.right._compiler_dispatch(self, **kw),
2407 ) + (
2408 " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
2409 if escape
2410 else ""
2411 )
2412
2413 def visit_ilike_op_binary(self, binary, operator, **kw):
2414 escape = binary.modifiers.get("escape", None)
2415 return "lower(%s) LIKE lower(%s)" % (
2416 binary.left._compiler_dispatch(self, **kw),
2417 binary.right._compiler_dispatch(self, **kw),
2418 ) + (
2419 " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
2420 if escape
2421 else ""
2422 )
2423
2424 def visit_not_ilike_op_binary(self, binary, operator, **kw):
2425 escape = binary.modifiers.get("escape", None)
2426 return "lower(%s) NOT LIKE lower(%s)" % (
2427 binary.left._compiler_dispatch(self, **kw),
2428 binary.right._compiler_dispatch(self, **kw),
2429 ) + (
2430 " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
2431 if escape
2432 else ""
2433 )
2434
2435 def visit_between_op_binary(self, binary, operator, **kw):
2436 symmetric = binary.modifiers.get("symmetric", False)
2437 return self._generate_generic_binary(
2438 binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw
2439 )
2440
2441 def visit_not_between_op_binary(self, binary, operator, **kw):
2442 symmetric = binary.modifiers.get("symmetric", False)
2443 return self._generate_generic_binary(
2444 binary,
2445 " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ",
2446 **kw
2447 )
2448
2449 def visit_regexp_match_op_binary(self, binary, operator, **kw):
2450 raise exc.CompileError(
2451 "%s dialect does not support regular expressions"
2452 % self.dialect.name
2453 )
2454
2455 def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
2456 raise exc.CompileError(
2457 "%s dialect does not support regular expressions"
2458 % self.dialect.name
2459 )
2460
2461 def visit_regexp_replace_op_binary(self, binary, operator, **kw):
2462 raise exc.CompileError(
2463 "%s dialect does not support regular expression replacements"
2464 % self.dialect.name
2465 )
2466
    def visit_bindparam(
        self,
        bindparam,
        within_columns_clause=False,
        literal_binds=False,
        skip_bind_expression=False,
        literal_execute=False,
        render_postcompile=False,
        **kwargs
    ):
        """Render a :class:`.BindParameter`.

        Depending on the flags the parameter is rendered:

        * wrapped in its type's bind expression, when the type defines one
          and ``skip_bind_expression`` is not set;
        * inline as a literal value, when ``literal_binds`` is set;
        * as a ``__[POSTCOMPILE_<name>]`` placeholder to be filled in later,
          when it is literal-execute or expanding;
        * otherwise as a regular bind placeholder via
          :meth:`.bindparam_string`.

        Expanding parameters are wrapped in parentheses.  For non-literal
        rendering, the parameter is registered in ``self.binds`` under both
        its key and its rendered name, and :class:`.exc.CompileError` is
        raised when that name is already taken by an incompatible bind
        parameter.
        """
        if not skip_bind_expression:
            impl = bindparam.type.dialect_impl(self.dialect)
            if impl._has_bind_expression:
                # the type wants the bind wrapped in a SQL expression; compile
                # that expression, with skip_bind_expression=True so this
                # same bind renders plainly when reached inside it
                bind_expression = impl.bind_expression(bindparam)
                wrapped = self.process(
                    bind_expression,
                    skip_bind_expression=True,
                    within_columns_clause=within_columns_clause,
                    literal_binds=literal_binds and not bindparam.expanding,
                    literal_execute=literal_execute,
                    render_postcompile=render_postcompile,
                    **kwargs
                )
                if bindparam.expanding:
                    # for postcompile w/ expanding, move the "wrapped" part
                    # of this into the inside
                    # i.e. "<prefix>(__[POSTCOMPILE_<name>])<suffix>" becomes
                    # "(__[POSTCOMPILE_<name>~~<prefix>~~REPL~~<suffix>~~])";
                    # presumably so the wrapping is applied to each element
                    # when the placeholder is expanded - confirm against the
                    # post-compile processing code

                    m = re.match(
                        r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
                    )
                    assert m, "unexpected format for expanding parameter"
                    wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
                        m.group(2),
                        m.group(1),
                        m.group(3),
                    )

                    if literal_binds:
                        # literal_binds was withheld from the process() call
                        # above for expanding params; render the literal
                        # values now using the wrapped text as template
                        ret = self.render_literal_bindparam(
                            bindparam,
                            within_columns_clause=True,
                            bind_expression_template=wrapped,
                            **kwargs
                        )
                        return "(%s)" % ret

                return wrapped

        if not literal_binds:
            # under ansi_bind_rules, binds in the columns clause are not
            # rendered as placeholders; treat them as literal_execute
            literal_execute = (
                literal_execute
                or bindparam.literal_execute
                or (within_columns_clause and self.ansi_bind_rules)
            )
            post_compile = literal_execute or bindparam.expanding
        else:
            post_compile = False

        if literal_binds:
            ret = self.render_literal_bindparam(
                bindparam, within_columns_clause=True, **kwargs
            )
            if bindparam.expanding:
                ret = "(%s)" % ret
            return ret

        name = self._truncate_bindparam(bindparam)

        if name in self.binds:
            # another bind parameter already rendered under this name;
            # reject combinations that can't share one placeholder
            existing = self.binds[name]
            if existing is not bindparam:
                if (
                    (existing.unique or bindparam.unique)
                    and not existing.proxy_set.intersection(
                        bindparam.proxy_set
                    )
                    and not existing._cloned_set.intersection(
                        bindparam._cloned_set
                    )
                ):
                    raise exc.CompileError(
                        "Bind parameter '%s' conflicts with "
                        "unique bind parameter of the same name" % name
                    )
                elif existing.expanding != bindparam.expanding:
                    raise exc.CompileError(
                        "Can't reuse bound parameter name '%s' in both "
                        "'expanding' (e.g. within an IN expression) and "
                        "non-expanding contexts. If this parameter is to "
                        "receive a list/array value, set 'expanding=True' on "
                        "it for expressions that aren't IN, otherwise use "
                        "a different parameter name." % (name,)
                    )
                elif existing._is_crud or bindparam._is_crud:
                    raise exc.CompileError(
                        "bindparam() name '%s' is reserved "
                        "for automatic usage in the VALUES or SET "
                        "clause of this "
                        "insert/update statement. Please use a "
                        "name other than column name when using bindparam() "
                        "with insert() or update() (for example, 'b_%s')."
                        % (bindparam.key, bindparam.key)
                    )

        # register under both the parameter's key and its rendered name
        self.binds[bindparam.key] = self.binds[name] = bindparam

        # if we are given a cache key that we're going to match against,
        # relate the bindparam here to one that is most likely present
        # in the "extracted params" portion of the cache key. this is used
        # to set up a positional mapping that is used to determine the
        # correct parameters for a subsequent use of this compiled with
        # a different set of parameter values. here, we accommodate for
        # parameters that may have been cloned both before and after the cache
        # key was generated.
        ckbm = self._cache_key_bind_match
        if ckbm:
            for bp in bindparam._cloned_set:
                if bp.key in ckbm:
                    cb = ckbm[bp.key]
                    ckbm[cb].append(bindparam)

        if bindparam.isoutparam:
            self.has_out_parameters = True

        if post_compile:
            if render_postcompile:
                self._render_postcompile = True

            # record which post-compile collection this bind belongs to
            if literal_execute:
                self.literal_execute_params |= {bindparam}
            else:
                self.post_compile_params |= {bindparam}

        ret = self.bindparam_string(
            name,
            post_compile=post_compile,
            expanding=bindparam.expanding,
            **kwargs
        )

        if bindparam.expanding:
            ret = "(%s)" % ret
        return ret
2610
2611 def render_literal_bindparam(
2612 self,
2613 bindparam,
2614 render_literal_value=NO_ARG,
2615 bind_expression_template=None,
2616 **kw
2617 ):
2618 if render_literal_value is not NO_ARG:
2619 value = render_literal_value
2620 else:
2621 if bindparam.value is None and bindparam.callable is None:
2622 op = kw.get("_binary_op", None)
2623 if op and op not in (operators.is_, operators.is_not):
2624 util.warn_limited(
2625 "Bound parameter '%s' rendering literal NULL in a SQL "
2626 "expression; comparisons to NULL should not use "
2627 "operators outside of 'is' or 'is not'",
2628 (bindparam.key,),
2629 )
2630 return self.process(sqltypes.NULLTYPE, **kw)
2631 value = bindparam.effective_value
2632
2633 if bindparam.expanding:
2634 leep = self._literal_execute_expanding_parameter_literal_binds
2635 to_update, replacement_expr = leep(
2636 bindparam,
2637 value,
2638 bind_expression_template=bind_expression_template,
2639 )
2640 return replacement_expr
2641 else:
2642 return self.render_literal_value(value, bindparam.type)
2643
2644 def render_literal_value(self, value, type_):
2645 """Render the value of a bind parameter as a quoted literal.
2646
2647 This is used for statement sections that do not accept bind parameters
2648 on the target driver/database.
2649
2650 This should be implemented by subclasses using the quoting services
2651 of the DBAPI.
2652
2653 """
2654
2655 processor = type_._cached_literal_processor(self.dialect)
2656 if processor:
2657 try:
2658 return processor(value)
2659 except Exception as e:
2660 util.raise_(
2661 exc.CompileError(
2662 "Could not render literal value "
2663 '"%s" '
2664 "with datatype "
2665 "%s; see parent stack trace for "
2666 "more detail."
2667 % (
2668 sql_util._repr_single_value(value),
2669 type_,
2670 )
2671 ),
2672 from_=e,
2673 )
2674
2675 else:
2676 raise exc.CompileError(
2677 "No literal value renderer is available for literal value "
2678 '"%s" with datatype %s'
2679 % (sql_util._repr_single_value(value), type_)
2680 )
2681
2682 def _truncate_bindparam(self, bindparam):
2683 if bindparam in self.bind_names:
2684 return self.bind_names[bindparam]
2685
2686 bind_name = bindparam.key
2687 if isinstance(bind_name, elements._truncated_label):
2688 bind_name = self._truncated_identifier("bindparam", bind_name)
2689
2690 # add to bind_names for translation
2691 self.bind_names[bindparam] = bind_name
2692
2693 return bind_name
2694
2695 def _truncated_identifier(self, ident_class, name):
2696 if (ident_class, name) in self.truncated_names:
2697 return self.truncated_names[(ident_class, name)]
2698
2699 anonname = name.apply_map(self.anon_map)
2700
2701 if len(anonname) > self.label_length - 6:
2702 counter = self.truncated_names.get(ident_class, 1)
2703 truncname = (
2704 anonname[0 : max(self.label_length - 6, 0)]
2705 + "_"
2706 + hex(counter)[2:]
2707 )
2708 self.truncated_names[ident_class] = counter + 1
2709 else:
2710 truncname = anonname
2711 self.truncated_names[(ident_class, name)] = truncname
2712 return truncname
2713
2714 def _anonymize(self, name):
2715 return name % self.anon_map
2716
2717 def bindparam_string(
2718 self,
2719 name,
2720 positional_names=None,
2721 post_compile=False,
2722 expanding=False,
2723 escaped_from=None,
2724 **kw
2725 ):
2726
2727 if self.positional:
2728 if positional_names is not None:
2729 positional_names.append(name)
2730 else:
2731 self.positiontup.append(name)
2732 self.positiontup_level[name] = len(self.stack)
2733 if not escaped_from:
2734
2735 if _BIND_TRANSLATE_RE.search(name):
2736 # not quite the translate use case as we want to
2737 # also get a quick boolean if we even found
2738 # unusual characters in the name
2739 new_name = _BIND_TRANSLATE_RE.sub(
2740 lambda m: _BIND_TRANSLATE_CHARS[m.group(0)],
2741 name,
2742 )
2743 escaped_from = name
2744 name = new_name
2745
2746 if escaped_from:
2747 if not self.escaped_bind_names:
2748 self.escaped_bind_names = {}
2749 self.escaped_bind_names[escaped_from] = name
2750 if post_compile:
2751 return "__[POSTCOMPILE_%s]" % name
2752 else:
2753 return self.bindtemplate % {"name": name}
2754
    def visit_cte(
        self,
        cte,
        asfrom=False,
        ashint=False,
        fromhints=None,
        visiting_cte=None,
        from_linter=None,
        **kwargs
    ):
        """Render a :class:`.CTE`.

        The CTE's definition text (``name[(cols)] AS [prefixes]\\n(inner)``)
        is stored in ``self.ctes`` keyed on the CTE, not returned.  Return
        values:

        * when ``asfrom`` is set, the text that references the CTE at this
          point in the FROM clause (its name, or ``<original> AS <name>``
          for an alias of another CTE);
        * when a new CTE is compiled with an empty ``self.stack`` (a direct
          stringify of the CTE), the compiled inner statement;
        * otherwise ``None``.

        CTEs are tracked by ``(level, name)`` in ``self.ctes_by_level_name``;
        nesting CTEs use the current stack depth as their level, all others
        share level 1.  A different, unrelated CTE using a name already
        taken at the same level raises :class:`.exc.CompileError`.
        """
        # NOTE(review): _init_cte_state() is not visible here; presumably
        # it sets up self.ctes and the related bookkeeping - confirm
        self._init_cte_state()

        kwargs["visiting_cte"] = cte

        cte_name = cte.name

        if isinstance(cte_name, elements._truncated_label):
            cte_name = self._truncated_identifier("alias", cte_name)

        is_new_cte = True
        embedded_in_current_named_cte = False

        _reference_cte = cte._get_reference_cte()

        if _reference_cte in self.level_name_by_cte:
            cte_level, _ = self.level_name_by_cte[_reference_cte]
            assert _ == cte_name
        else:
            # nesting CTEs are scoped to the current statement depth;
            # all other CTEs are tracked at level 1
            cte_level = len(self.stack) if cte.nesting else 1

        cte_level_name = (cte_level, cte_name)
        if cte_level_name in self.ctes_by_level_name:
            existing_cte = self.ctes_by_level_name[cte_level_name]
            embedded_in_current_named_cte = visiting_cte is existing_cte

            # we've generated a same-named CTE that we are enclosed in,
            # or this is the same CTE. just return the name.
            if cte is existing_cte._restates or cte is existing_cte:
                is_new_cte = False
            elif existing_cte is cte._restates:
                # we've generated a same-named CTE that is
                # enclosed in us - we take precedence, so
                # discard the text for the "inner".
                del self.ctes[existing_cte]

                existing_cte_reference_cte = existing_cte._get_reference_cte()

                # TODO: determine if these assertions are correct. they
                # pass for current test cases
                # assert existing_cte_reference_cte is _reference_cte
                # assert existing_cte_reference_cte is existing_cte

                del self.level_name_by_cte[existing_cte_reference_cte]
            else:
                # if the two CTEs are deep-copy identical, consider them
                # the same, **if** they are clones, that is, they came from
                # the ORM or other visit method
                if (
                    cte._is_clone_of is not None
                    or existing_cte._is_clone_of is not None
                ) and cte.compare(existing_cte):
                    is_new_cte = False
                else:
                    raise exc.CompileError(
                        "Multiple, unrelated CTEs found with "
                        "the same name: %r" % cte_name
                    )

        # already rendered, and not being referenced from a FROM clause:
        # nothing to emit here
        if not asfrom and not is_new_cte:
            return None

        # a CTE that aliases another CTE is referenced as
        # "<original name> AS <alias name>"
        if cte._cte_alias is not None:
            pre_alias_cte = cte._cte_alias
            cte_pre_alias_name = cte._cte_alias.name
            if isinstance(cte_pre_alias_name, elements._truncated_label):
                cte_pre_alias_name = self._truncated_identifier(
                    "alias", cte_pre_alias_name
                )
        else:
            pre_alias_cte = cte
            cte_pre_alias_name = None

        if is_new_cte:
            self.ctes_by_level_name[cte_level_name] = cte
            self.level_name_by_cte[_reference_cte] = cte_level_name

            # propagate an explicit "autocommit" execution option from the
            # CTE's statement unless the outer statement sets its own
            if (
                "autocommit" in cte.element._execution_options
                and "autocommit" not in self.execution_options
            ):
                self.execution_options = self.execution_options.union(
                    {
                        "autocommit": cte.element._execution_options[
                            "autocommit"
                        ]
                    }
                )
            if self.positional:
                self.cte_level[cte] = cte_level

            # make sure the aliased CTE (or this CTE itself when it is not an
            # alias) has been rendered into self.ctes.  In the non-alias case
            # the recursive call finds the CTE registered just above and
            # returns without rendering.
            if pre_alias_cte not in self.ctes:
                self.visit_cte(pre_alias_cte, **kwargs)

            if not cte_pre_alias_name and cte not in self.ctes:
                if cte.recursive:
                    self.ctes_recursive = True
                text = self.preparer.format_alias(cte, cte_name)
                if cte.recursive:
                    # a recursive CTE renders an explicit column name list
                    if isinstance(cte.element, selectable.Select):
                        col_source = cte.element
                    elif isinstance(cte.element, selectable.CompoundSelect):
                        col_source = cte.element.selects[0]
                    else:
                        assert False, "cte should only be against SelectBase"

                    # TODO: can we get at the .columns_plus_names collection
                    # that is already (or will be?) generated for the SELECT
                    # rather than calling twice?
                    recur_cols = [
                        # TODO: proxy_name is not technically safe,
                        # see test_cte->
                        # test_with_recursive_no_name_currently_buggy. not
                        # clear what should be done with such a case
                        fallback_label_name or proxy_name
                        for (
                            _,
                            proxy_name,
                            fallback_label_name,
                            c,
                            repeated,
                        ) in (col_source._generate_columns_plus_names(True))
                        if not repeated
                    ]

                    text += "(%s)" % (
                        ", ".join(
                            self.preparer.format_label_name(
                                ident, anon_map=self.anon_map
                            )
                            for ident in recur_cols
                        )
                    )

                if self.positional:
                    # collect this CTE's positional bind names separately
                    kwargs["positional_names"] = self.cte_positional[cte] = []

                assert kwargs.get("subquery", False) is False

                if not self.stack:
                    # toplevel, this is a stringify of the
                    # cte directly. just compile the inner
                    # the way alias() does.
                    return cte.element._compiler_dispatch(
                        self, asfrom=asfrom, **kwargs
                    )
                else:
                    prefixes = self._generate_prefixes(
                        cte, cte._prefixes, **kwargs
                    )
                    inner = cte.element._compiler_dispatch(
                        self, asfrom=True, **kwargs
                    )

                    text += " AS %s\n(%s)" % (prefixes, inner)

                if cte._suffixes:
                    text += " " + self._generate_prefixes(
                        cte, cte._suffixes, **kwargs
                    )

                self.ctes[cte] = text

        if asfrom:
            if from_linter:
                from_linter.froms[cte] = cte_name

            if not is_new_cte and embedded_in_current_named_cte:
                return self.preparer.format_alias(cte, cte_name)

            if cte_pre_alias_name:
                text = self.preparer.format_alias(cte, cte_pre_alias_name)
                if self.preparer._requires_quotes(cte_name):
                    cte_name = self.preparer.quote(cte_name)
                text += self.get_render_as_alias_suffix(cte_name)
                return text
            else:
                return self.preparer.format_alias(cte, cte_name)
2942
2943 def visit_table_valued_alias(self, element, **kw):
2944 if element.joins_implicitly:
2945 kw["from_linter"] = None
2946 if element._is_lateral:
2947 return self.visit_lateral(element, **kw)
2948 else:
2949 return self.visit_alias(element, **kw)
2950
2951 def visit_table_valued_column(self, element, **kw):
2952 return self.visit_column(element, **kw)
2953
    def visit_alias(
        self,
        alias,
        asfrom=False,
        ashint=False,
        iscrud=False,
        fromhints=None,
        subquery=False,
        lateral=False,
        enclosing_alias=None,
        from_linter=None,
        **kwargs
    ):
        """Render an :class:`.Alias` (also used for subqueries, laterals and
        table samples).

        * ``ashint``: return only the formatted alias name.
        * ``asfrom``: render the aliased element followed by
          `` AS <name>``, parenthesized when ``subquery`` is set, optionally
          followed by a derived column list and FROM hints.
        * otherwise: render the aliased element without naming it.

        When this alias is the direct element of ``enclosing_alias`` (an
        alias of an alias), its element is rendered in place with the
        caller's flags instead of adding another level of naming.
        """

        if lateral:
            if "enclosing_lateral" not in kwargs:
                # if lateral is set and enclosing_lateral is not
                # present, we assume we are being called directly
                # from visit_lateral() and we need to set enclosing_lateral.
                assert alias._is_lateral
                kwargs["enclosing_lateral"] = alias

            # for lateral objects, we track a second from_linter that is...
            # lateral! to the level above us.
            if (
                from_linter
                and "lateral_from_linter" not in kwargs
                and "enclosing_lateral" in kwargs
            ):
                kwargs["lateral_from_linter"] = from_linter

        if enclosing_alias is not None and enclosing_alias.element is alias:
            # alias-of-alias: render the inner element directly, passing the
            # caller's flags through
            inner = alias.element._compiler_dispatch(
                self,
                asfrom=asfrom,
                ashint=ashint,
                iscrud=iscrud,
                fromhints=fromhints,
                lateral=lateral,
                enclosing_alias=alias,
                **kwargs
            )
            if subquery and (asfrom or lateral):
                inner = "(%s)" % (inner,)
            return inner
        else:
            enclosing_alias = kwargs["enclosing_alias"] = alias

        if asfrom or ashint:
            if isinstance(alias.name, elements._truncated_label):
                alias_name = self._truncated_identifier("alias", alias.name)
            else:
                alias_name = alias.name

        if ashint:
            return self.preparer.format_alias(alias, alias_name)
        elif asfrom:
            if from_linter:
                from_linter.froms[alias] = alias_name

            inner = alias.element._compiler_dispatch(
                self, asfrom=True, lateral=lateral, **kwargs
            )
            if subquery:
                inner = "(%s)" % (inner,)

            ret = inner + self.get_render_as_alias_suffix(
                self.preparer.format_alias(alias, alias_name)
            )

            # derived column list "(col [type], ...)" after the alias name,
            # with types included when _render_derived_w_types is set
            if alias._supports_derived_columns and alias._render_derived:
                ret += "(%s)" % (
                    ", ".join(
                        "%s%s"
                        % (
                            self.preparer.quote(col.name),
                            " %s"
                            % self.dialect.type_compiler.process(
                                col.type, **kwargs
                            )
                            if alias._render_derived_w_types
                            else "",
                        )
                        for col in alias.c
                    )
                )

            if fromhints and alias in fromhints:
                ret = self.format_from_hint_text(
                    ret, alias, fromhints[alias], iscrud
                )

            return ret
        else:
            # note we cancel the "subquery" flag here as well
            return alias.element._compiler_dispatch(
                self, lateral=lateral, **kwargs
            )
3052
3053 def visit_subquery(self, subquery, **kw):
3054 kw["subquery"] = True
3055 return self.visit_alias(subquery, **kw)
3056
3057 def visit_lateral(self, lateral_, **kw):
3058 kw["lateral"] = True
3059 return "LATERAL %s" % self.visit_alias(lateral_, **kw)
3060
3061 def visit_tablesample(self, tablesample, asfrom=False, **kw):
3062 text = "%s TABLESAMPLE %s" % (
3063 self.visit_alias(tablesample, asfrom=True, **kw),
3064 tablesample._get_method()._compiler_dispatch(self, **kw),
3065 )
3066
3067 if tablesample.seed is not None:
3068 text += " REPEATABLE (%s)" % (
3069 tablesample.seed._compiler_dispatch(self, **kw)
3070 )
3071
3072 return text
3073
3074 def visit_values(self, element, asfrom=False, from_linter=None, **kw):
3075 kw.setdefault("literal_binds", element.literal_binds)
3076 v = "VALUES %s" % ", ".join(
3077 self.process(
3078 elements.Tuple(
3079 types=element._column_types, *elem
3080 ).self_group(),
3081 **kw
3082 )
3083 for chunk in element._data
3084 for elem in chunk
3085 )
3086
3087 if isinstance(element.name, elements._truncated_label):
3088 name = self._truncated_identifier("values", element.name)
3089 else:
3090 name = element.name
3091
3092 if element._is_lateral:
3093 lateral = "LATERAL "
3094 else:
3095 lateral = ""
3096
3097 if asfrom:
3098 if from_linter:
3099 from_linter.froms[element] = (
3100 name if name is not None else "(unnamed VALUES element)"
3101 )
3102
3103 if name:
3104 v = "%s(%s)%s (%s)" % (
3105 lateral,
3106 v,
3107 self.get_render_as_alias_suffix(self.preparer.quote(name)),
3108 (
3109 ", ".join(
3110 c._compiler_dispatch(
3111 self, include_table=False, **kw
3112 )
3113 for c in element.columns
3114 )
3115 ),
3116 )
3117 else:
3118 v = "%s(%s)" % (lateral, v)
3119 return v
3120
3121 def get_render_as_alias_suffix(self, alias_name_text):
3122 return " AS " + alias_name_text
3123
3124 def _add_to_result_map(self, keyname, name, objects, type_):
3125 if keyname is None or keyname == "*":
3126 self._ordered_columns = False
3127 self._ad_hoc_textual = True
3128 if type_._is_tuple_type:
3129 raise exc.CompileError(
3130 "Most backends don't support SELECTing "
3131 "from a tuple() object. If this is an ORM query, "
3132 "consider using the Bundle object."
3133 )
3134 self._result_columns.append((keyname, name, objects, type_))
3135
3136 def _label_returning_column(
3137 self, stmt, column, column_clause_args=None, **kw
3138 ):
3139 """Render a column with necessary labels inside of a RETURNING clause.
3140
3141 This method is provided for individual dialects in place of calling
3142 the _label_select_column method directly, so that the two use cases
3143 of RETURNING vs. SELECT can be disambiguated going forward.
3144
3145 .. versionadded:: 1.4.21
3146
3147 """
3148 return self._label_select_column(
3149 None,
3150 column,
3151 True,
3152 False,
3153 {} if column_clause_args is None else column_clause_args,
3154 **kw
3155 )
3156
    def _label_select_column(
        self,
        select,
        column,
        populate_result_map,
        asfrom,
        column_clause_args,
        name=None,
        proxy_name=None,
        fallback_label_name=None,
        within_columns_clause=True,
        column_is_repeated=False,
        need_column_expressions=False,
    ):
        """produce labeled columns present in a select().

        Compiles ``column`` for the columns clause of a SELECT (also used for
        RETURNING) and returns the rendered text, wrapping it in a label
        when one is required:

        * an explicit :class:`.Label` is kept, re-labeled around the type's
          column expression if one applies;
        * an explicit ``name`` (which requires ``proxy_name``) is always
          used as the label;
        * otherwise heuristics based on the kind of expression decide
          whether to label it with ``fallback_label_name`` (or the column's
          anonymous label).

        When ``populate_result_map`` is set, an ``add_to_result_map``
        callable is passed into the compilation to record result columns;
        repeated columns record no target objects.  ``column_clause_args``
        is updated in place with ``within_columns_clause`` and
        ``add_to_result_map`` and used as the compile keyword arguments.
        ``select`` is not referenced in this method.
        """
        impl = column.type.dialect_impl(self.dialect)

        # the type may wrap the column in a SQL expression when it is
        # being fetched
        if impl._has_column_expression and (
            need_column_expressions or populate_result_map
        ):
            col_expr = impl.column_expression(column)
        else:
            col_expr = column

        if populate_result_map:
            # pass an "add_to_result_map" callable into the compilation
            # of embedded columns. this collects information about the
            # column as it will be fetched in the result and is coordinated
            # with cursor.description when the query is executed.
            add_to_result_map = self._add_to_result_map

            # if the SELECT statement told us this column is a repeat,
            # wrap the callable with one that prevents the addition of the
            # targets
            if column_is_repeated:
                _add_to_result_map = add_to_result_map

                def add_to_result_map(keyname, name, objects, type_):
                    _add_to_result_map(keyname, name, (), type_)

            # if we redefined col_expr for type expressions, wrap the
            # callable with one that adds the original column to the targets
            elif col_expr is not column:
                _add_to_result_map = add_to_result_map

                def add_to_result_map(keyname, name, objects, type_):
                    _add_to_result_map(
                        keyname, name, (column,) + objects, type_
                    )

        else:
            add_to_result_map = None

        # this method is used by some of the dialects for RETURNING,
        # which has different inputs. _label_returning_column was added
        # as the better target for this now however for 1.4 we will keep
        # _label_select_column directly compatible with this use case.
        # these assertions right now set up the current expected inputs
        assert within_columns_clause, (
            "_label_select_column is only relevant within "
            "the columns clause of a SELECT or RETURNING"
        )
        if isinstance(column, elements.Label):
            if col_expr is not column:
                result_expr = _CompileLabel(
                    col_expr, column.name, alt_names=(column.element,)
                )
            else:
                result_expr = col_expr

        elif name:
            # here, _columns_plus_names has determined there's an explicit
            # label name we need to use. this is the default for
            # tablenames_plus_columnnames as well as when columns are being
            # deduplicated on name

            assert (
                proxy_name is not None
            ), "proxy_name is required if 'name' is passed"

            result_expr = _CompileLabel(
                col_expr,
                name,
                alt_names=(
                    proxy_name,
                    # this is a hack to allow legacy result column lookups
                    # to work as they did before; this goes away in 2.0.
                    # TODO: this only seems to be tested indirectly
                    # via test/orm/test_deprecations.py. should be a
                    # resultset test for this
                    column._tq_label,
                ),
            )
        else:
            # determine here whether this column should be rendered in
            # a labelled context or not, as we were given no required label
            # name from the caller. Here we apply heuristics based on the kind
            # of SQL expression involved.

            if col_expr is not column:
                # type-specific expression wrapping the given column,
                # so we render a label
                render_with_label = True
            elif isinstance(column, elements.ColumnClause):
                # table-bound column, we render its name as a label if we are
                # inside of a subquery only
                render_with_label = (
                    asfrom
                    and not column.is_literal
                    and column.table is not None
                )
            elif isinstance(column, elements.TextClause):
                render_with_label = False
            elif isinstance(column, elements.UnaryExpression):
                render_with_label = column.wraps_column_expression or asfrom
            elif (
                # general class of expressions that don't have a SQL-column
                # addressable name. includes scalar selects, bind parameters,
                # SQL functions, others
                not isinstance(column, elements.NamedColumn)
                # deeper check that indicates there's no natural "name" to
                # this element, which accommodates for custom SQL constructs
                # that might have a ".name" attribute (but aren't SQL
                # functions) but are not implementing this more recently added
                # base class. in theory the "NamedColumn" check should be
                # enough, however here we seek to maintain legacy behaviors
                # as well.
                and column._non_anon_label is None
            ):
                render_with_label = True
            else:
                render_with_label = False

            if render_with_label:
                if not fallback_label_name:
                    # used by the RETURNING case right now. we generate it
                    # here as 3rd party dialects may be referring to
                    # _label_select_column method directly instead of the
                    # just-added _label_returning_column method
                    assert not column_is_repeated
                    fallback_label_name = column._anon_name_label

                # ensure the label is a _truncated_label so it goes through
                # label-length truncation when rendered
                fallback_label_name = (
                    elements._truncated_label(fallback_label_name)
                    if not isinstance(
                        fallback_label_name, elements._truncated_label
                    )
                    else fallback_label_name
                )

                result_expr = _CompileLabel(
                    col_expr, fallback_label_name, alt_names=(proxy_name,)
                )
            else:
                result_expr = col_expr

        # NOTE: mutates the caller's dict
        column_clause_args.update(
            within_columns_clause=within_columns_clause,
            add_to_result_map=add_to_result_map,
        )
        return result_expr._compiler_dispatch(self, **column_clause_args)
3318
3319 def format_from_hint_text(self, sqltext, table, hint, iscrud):
3320 hinttext = self.get_from_hint_text(table, hint)
3321 if hinttext:
3322 sqltext += " " + hinttext
3323 return sqltext
3324
3325 def get_select_hint_text(self, byfroms):
3326 return None
3327
3328 def get_from_hint_text(self, table, text):
3329 return None
3330
3331 def get_crud_hint_text(self, table, text):
3332 return None
3333
3334 def get_statement_hint_text(self, hint_texts):
3335 return " ".join(hint_texts)
3336
    # Stack entry used in place of ``self.stack[-1]`` when compiling a
    # top-level statement (empty ``self.stack``): nothing is available to
    # correlate against yet, so both FROM sets are empty.  It is a shared
    # class-level object built from frozensets inside an immutabledict,
    # presumably to guard it against accidental in-place mutation.
    _default_stack_entry = util.immutabledict(
        [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
    )
3340
3341 def _display_froms_for_select(
3342 self, select_stmt, asfrom, lateral=False, **kw
3343 ):
3344 # utility method to help external dialects
3345 # get the correct from list for a select.
3346 # specifically the oracle dialect needs this feature
3347 # right now.
3348 toplevel = not self.stack
3349 entry = self._default_stack_entry if toplevel else self.stack[-1]
3350
3351 compile_state = select_stmt._compile_state_factory(select_stmt, self)
3352
3353 correlate_froms = entry["correlate_froms"]
3354 asfrom_froms = entry["asfrom_froms"]
3355
3356 if asfrom and not lateral:
3357 froms = compile_state._get_display_froms(
3358 explicit_correlate_froms=correlate_froms.difference(
3359 asfrom_froms
3360 ),
3361 implicit_correlate_froms=(),
3362 )
3363 else:
3364 froms = compile_state._get_display_froms(
3365 explicit_correlate_froms=correlate_froms,
3366 implicit_correlate_froms=asfrom_froms,
3367 )
3368 return froms
3369
    translate_select_structure = None
    """If not ``None``, a callable that :meth:`visit_select` uses to
    restructure a SELECT before rendering it.

    It is invoked as ``self.translate_select_structure(select_stmt,
    asfrom=asfrom, **kw)`` and must return a select object.  If the returned
    object is not the one passed in, :meth:`visit_select` builds a new
    compile state for it and renders that SELECT instead.  Result-map
    entries are then rewritten to point back at the original statement's
    columns.  Mostly used for structural changes that accommodate
    LIMIT/OFFSET schemes, such as those of Oracle and SQL Server.

    """
3376
    def visit_select(
        self,
        select_stmt,
        asfrom=False,
        insert_into=False,
        fromhints=None,
        compound_index=None,
        select_wraps_for=None,
        lateral=False,
        from_linter=None,
        **kwargs
    ):
        """Compile a SELECT statement into a SQL string.

        :param select_stmt: the SELECT to compile.  Its compile state factory
         may produce a different statement (e.g. ORM-state SELECT to Core
         SELECT), and the produced statement is what gets rendered.
        :param asfrom: True when the SELECT is rendered as a FROM element.
        :param insert_into: True when the SELECT is embedded in an
         INSERT..FROM SELECT.  Together with ``compound_index`` it marks the
         statement as "embedded", which affects where CTEs are rendered.
        :param fromhints: accepted for interface compatibility.  It is not
         referenced in this method body.
        :param compound_index: position of this SELECT inside a compound
         statement, or None when it is not part of one.
        :param select_wraps_for: must be None.  Structural rewrites go
         through :attr:`translate_select_structure` instead.
        :param lateral: True for a LATERAL subquery.
        :param from_linter: accepted but not referenced here.  FROM linting
         is set up in :meth:`_compose_select_body`.
        :return: the rendered SQL text.

        A stack entry is pushed by :meth:`_setup_select_stack` and popped
        just before returning.
        """
        assert select_wraps_for is None, (
            "SQLAlchemy 1.4 requires use of "
            "the translate_select_structure hook for structural "
            "translations of SELECT objects"
        )

        # initial setup of SELECT.  the compile_state_factory may now
        # be creating a totally different SELECT from the one that was
        # passed in.  for ORM use this will convert from an ORM-state
        # SELECT to a regular "Core" SELECT.  other composed operations
        # such as computation of joins will be performed.

        kwargs["within_columns_clause"] = False

        compile_state = select_stmt._compile_state_factory(
            select_stmt, self, **kwargs
        )
        select_stmt = compile_state.statement

        toplevel = not self.stack

        # the first top-level compile state seen becomes the compiler's
        # primary compile state
        if toplevel and not self.compile_state:
            self.compile_state = compile_state

        # embedded SELECTs (compound members, INSERT..FROM SELECT) normally
        # leave CTE rendering to the enclosing statement; see the CTE
        # rendering step below
        is_embedded_select = compound_index is not None or insert_into

        # translate step for Oracle, SQL Server which often need to
        # restructure the SELECT to allow for LIMIT/OFFSET and possibly
        # other conditions
        if self.translate_select_structure:
            new_select_stmt = self.translate_select_structure(
                select_stmt, asfrom=asfrom, **kwargs
            )

            # if SELECT was restructured, maintain a link to the originals
            # and assemble a new compile state
            if new_select_stmt is not select_stmt:
                # compile_state_wraps_for is only bound on this path;
                # select_wraps_for is non-None only on this path too
                compile_state_wraps_for = compile_state
                select_wraps_for = select_stmt
                select_stmt = new_select_stmt

                compile_state = select_stmt._compile_state_factory(
                    select_stmt, self, **kwargs
                )
                select_stmt = compile_state.statement

        entry = self._default_stack_entry if toplevel else self.stack[-1]

        populate_result_map = need_column_expressions = (
            toplevel
            or entry.get("need_result_map_for_compound", False)
            or entry.get("need_result_map_for_nested", False)
        )

        # indicates there is a CompoundSelect in play and we are not the
        # first select (compound_index 0 is falsy, so the first member
        # still populates the result map)
        if compound_index:
            populate_result_map = False

        # this was first proposed as part of #3372; however, it is not
        # reached in current tests and could possibly be an assertion
        # instead.
        if not populate_result_map and "add_to_result_map" in kwargs:
            del kwargs["add_to_result_map"]

        # validates compound column counts, computes FROMs and pushes a
        # new stack entry (popped at the end of this method)
        froms = self._setup_select_stack(
            select_stmt, compile_state, entry, asfrom, lateral, compound_index
        )

        column_clause_args = kwargs.copy()
        column_clause_args.update(
            {"within_label_clause": False, "within_columns_clause": False}
        )

        text = "SELECT "  # we're off to a good start !

        # byfrom maps FROM objects to rendered per-FROM hints; it is passed
        # on to _compose_select_body
        if select_stmt._hints:
            hint_text, byfrom = self._setup_select_hints(select_stmt)
            if hint_text:
                text += hint_text + " "
        else:
            byfrom = None

        # independent CTEs are compiled for their side effect of being
        # registered with the compiler; their text is not used here
        if select_stmt._independent_ctes:
            for cte in select_stmt._independent_ctes:
                cte._compiler_dispatch(self, **kwargs)

        if select_stmt._prefixes:
            text += self._generate_prefixes(
                select_stmt, select_stmt._prefixes, **kwargs
            )

        text += self.get_select_precolumns(select_stmt, **kwargs)
        # the actual list of columns to print in the SELECT column list.
        # _label_select_column may return None for a column, which is then
        # omitted.
        inner_columns = [
            c
            for c in [
                self._label_select_column(
                    select_stmt,
                    column,
                    populate_result_map,
                    asfrom,
                    column_clause_args,
                    name=name,
                    proxy_name=proxy_name,
                    fallback_label_name=fallback_label_name,
                    column_is_repeated=repeated,
                    need_column_expressions=need_column_expressions,
                )
                for (
                    name,
                    proxy_name,
                    fallback_label_name,
                    column,
                    repeated,
                ) in compile_state.columns_plus_names
            ]
            if c is not None
        ]

        if populate_result_map and select_wraps_for is not None:
            # if this select was generated from translate_select,
            # rewrite the targeted columns in the result map.
            # columns_plus_names entries are (name, proxy_name,
            # fallback_label_name, column, repeated) tuples (see the
            # unpacking above); the element bound to "name" below is the
            # fourth item, i.e. the column.  The dict therefore maps each
            # column of the restructured SELECT to the positionally
            # corresponding column of the original SELECT.

            translate = dict(
                zip(
                    [
                        name
                        for (
                            key,
                            proxy_name,
                            fallback_label_name,
                            name,
                            repeated,
                        ) in compile_state.columns_plus_names
                    ],
                    [
                        name
                        for (
                            key,
                            proxy_name,
                            fallback_label_name,
                            name,
                            repeated,
                        ) in compile_state_wraps_for.columns_plus_names
                    ],
                )
            )

            self._result_columns = [
                (key, name, tuple(translate.get(o, o) for o in obj), type_)
                for key, name, obj, type_ in self._result_columns
            ]

        text = self._compose_select_body(
            text,
            select_stmt,
            compile_state,
            inner_columns,
            froms,
            byfrom,
            toplevel,
            kwargs,
        )

        if select_stmt._statement_hints:
            per_dialect = [
                ht
                for (dialect_name, ht) in select_stmt._statement_hints
                if dialect_name in ("*", self.dialect.name)
            ]
            if per_dialect:
                text += " " + self.get_statement_hint_text(per_dialect)

        # In compound query, CTEs are shared at the compound level
        if self.ctes and (not is_embedded_select or toplevel):
            nesting_level = len(self.stack) if not toplevel else None
            text = (
                self._render_cte_clause(
                    nesting_level=nesting_level,
                    visiting_cte=kwargs.get("visiting_cte"),
                )
                + text
            )

        if select_stmt._suffixes:
            text += " " + self._generate_prefixes(
                select_stmt, select_stmt._suffixes, **kwargs
            )

        # pop the entry pushed by _setup_select_stack
        self.stack.pop(-1)

        return text
3582
3583 def _setup_select_hints(self, select):
3584 byfrom = dict(
3585 [
3586 (
3587 from_,
3588 hinttext
3589 % {"name": from_._compiler_dispatch(self, ashint=True)},
3590 )
3591 for (from_, dialect), hinttext in select._hints.items()
3592 if dialect in ("*", self.dialect.name)
3593 ]
3594 )
3595 hint_text = self.get_select_hint_text(byfrom)
3596 return hint_text, byfrom
3597
3598 def _setup_select_stack(
3599 self, select, compile_state, entry, asfrom, lateral, compound_index
3600 ):
3601 correlate_froms = entry["correlate_froms"]
3602 asfrom_froms = entry["asfrom_froms"]
3603
3604 if compound_index == 0:
3605 entry["select_0"] = select
3606 elif compound_index:
3607 select_0 = entry["select_0"]
3608 numcols = len(select_0._all_selected_columns)
3609
3610 if len(compile_state.columns_plus_names) != numcols:
3611 raise exc.CompileError(
3612 "All selectables passed to "
3613 "CompoundSelect must have identical numbers of "
3614 "columns; select #%d has %d columns, select "
3615 "#%d has %d"
3616 % (
3617 1,
3618 numcols,
3619 compound_index + 1,
3620 len(select._all_selected_columns),
3621 )
3622 )
3623
3624 if asfrom and not lateral:
3625 froms = compile_state._get_display_froms(
3626 explicit_correlate_froms=correlate_froms.difference(
3627 asfrom_froms
3628 ),
3629 implicit_correlate_froms=(),
3630 )
3631 else:
3632 froms = compile_state._get_display_froms(
3633 explicit_correlate_froms=correlate_froms,
3634 implicit_correlate_froms=asfrom_froms,
3635 )
3636
3637 new_correlate_froms = set(selectable._from_objects(*froms))
3638 all_correlate_froms = new_correlate_froms.union(correlate_froms)
3639
3640 new_entry = {
3641 "asfrom_froms": new_correlate_froms,
3642 "correlate_froms": all_correlate_froms,
3643 "selectable": select,
3644 "compile_state": compile_state,
3645 }
3646 self.stack.append(new_entry)
3647
3648 return froms
3649
3650 def _compose_select_body(
3651 self,
3652 text,
3653 select,
3654 compile_state,
3655 inner_columns,
3656 froms,
3657 byfrom,
3658 toplevel,
3659 kwargs,
3660 ):
3661 text += ", ".join(inner_columns)
3662
3663 if self.linting & COLLECT_CARTESIAN_PRODUCTS:
3664 from_linter = FromLinter({}, set())
3665 warn_linting = self.linting & WARN_LINTING
3666 if toplevel:
3667 self.from_linter = from_linter
3668 else:
3669 from_linter = None
3670 warn_linting = False
3671
3672 if froms:
3673 text += " \nFROM "
3674
3675 if select._hints:
3676 text += ", ".join(
3677 [
3678 f._compiler_dispatch(
3679 self,
3680 asfrom=True,
3681 fromhints=byfrom,
3682 from_linter=from_linter,
3683 **kwargs
3684 )
3685 for f in froms
3686 ]
3687 )
3688 else:
3689 text += ", ".join(
3690 [
3691 f._compiler_dispatch(
3692 self,
3693 asfrom=True,
3694 from_linter=from_linter,
3695 **kwargs
3696 )
3697 for f in froms
3698 ]
3699 )
3700 else:
3701 text += self.default_from()
3702
3703 if select._where_criteria:
3704 t = self._generate_delimited_and_list(
3705 select._where_criteria, from_linter=from_linter, **kwargs
3706 )
3707 if t:
3708 text += " \nWHERE " + t
3709
3710 if warn_linting:
3711 from_linter.warn()
3712
3713 if select._group_by_clauses:
3714 text += self.group_by_clause(select, **kwargs)
3715
3716 if select._having_criteria:
3717 t = self._generate_delimited_and_list(
3718 select._having_criteria, **kwargs
3719 )
3720 if t:
3721 text += " \nHAVING " + t
3722
3723 if select._order_by_clauses:
3724 text += self.order_by_clause(select, **kwargs)
3725
3726 if select._has_row_limiting_clause:
3727 text += self._row_limit_clause(select, **kwargs)
3728
3729 if select._for_update_arg is not None:
3730 text += self.for_update_clause(select, **kwargs)
3731
3732 return text
3733
3734 def _generate_prefixes(self, stmt, prefixes, **kw):
3735 clause = " ".join(
3736 prefix._compiler_dispatch(self, **kw)
3737 for prefix, dialect_name in prefixes
3738 if dialect_name is None or dialect_name == self.dialect.name
3739 )
3740 if clause:
3741 clause += " "
3742 return clause
3743
    def _render_cte_clause(
        self,
        nesting_level=None,
        include_following_stack=False,
        visiting_cte=None,
    ):
        """Render the WITH clause for CTEs collected in ``self.ctes``.

        nesting_level
            Stack depth of the statement being rendered, or None for a
            top-level statement.  When greater than 1, only "nesting" CTEs
            registered at this level are rendered.  They are then removed
            from ``self.ctes``, ``self.ctes_by_level_name`` and
            ``self.level_name_by_cte`` so they are not rendered again.

        include_following_stack
            Also render the nesting CTEs on the next stack. Useful for
            SQL structures like UNION or INSERT that can wrap SELECT
            statements containing nesting CTEs.

        visiting_cte
            The CTE currently being visited, if any.  Under positional
            parameter style it is used as the key into ``self.cte_order``.

        Returns an empty string when there is nothing to render.  Otherwise
        returns the WITH preamble, the CTE texts joined by a comma and
        newline, and a trailing newline plus space.
        """
        if not self.ctes:
            return ""

        if nesting_level and nesting_level > 1:
            # nested statement: keep only nesting CTEs registered at this
            # level, or at the next level when include_following_stack
            ctes = util.OrderedDict()
            for cte in list(self.ctes.keys()):
                cte_level, cte_name = self.level_name_by_cte[
                    cte._get_reference_cte()
                ]
                is_rendered_level = cte_level == nesting_level or (
                    include_following_stack and cte_level == nesting_level + 1
                )
                if not (cte.nesting and is_rendered_level):
                    continue

                ctes[cte] = self.ctes[cte]

        else:
            ctes = self.ctes

        if not ctes:
            return ""
        ctes_recursive = any([cte.recursive for cte in ctes])

        if self.positional:
            # record which CTEs were rendered while visiting visiting_cte
            self.cte_order[visiting_cte].extend(ctes)

            if visiting_cte is None and self.cte_order:
                # top-level rendering: reorder self.positiontup so that the
                # bind names collected for CTEs (cte_positional) are spliced
                # in per cte_order / level bookkeeping; presumably this keeps
                # positional parameters aligned with the final SQL text
                assert self.positiontup is not None

                def get_nested_positional(cte):
                    # bind names for ``cte`` combined with those of the CTEs
                    # rendered while visiting it, recursively
                    if cte in self.cte_order:
                        children = self.cte_order.pop(cte)
                        to_add = list(
                            itertools.chain.from_iterable(
                                get_nested_positional(child_cte)
                                for child_cte in children
                            )
                        )
                        if cte in self.cte_positional:
                            return reorder_positional(
                                self.cte_positional[cte],
                                to_add,
                                self.cte_level[children[0]],
                            )
                        else:
                            return to_add
                    else:
                        return self.cte_positional.get(cte, [])

                def reorder_positional(pos, to_add, level):
                    # with no level, to_add goes in front; otherwise it is
                    # inserted relative to the last name (scanning from the
                    # end) whose positiontup_level is below ``level``.
                    # NOTE(review): if the scan breaks at index 0, pos[:-0]
                    # is empty and pos[-0:] is all of pos, so to_add is
                    # prepended rather than appended -- confirm intended.
                    if not level:
                        return to_add + pos
                    index = 0
                    for index, name in enumerate(reversed(pos)):
                        if self.positiontup_level[name] < level:  # type: ignore[index] # noqa: E501
                            break
                    return pos[:-index] + to_add + pos[-index:]

                to_add = get_nested_positional(None)
                self.positiontup = reorder_positional(
                    self.positiontup, to_add, nesting_level
                )

        cte_text = self.get_cte_preamble(ctes_recursive) + " "
        cte_text += ", \n".join([txt for txt in ctes.values()])
        cte_text += "\n "

        if nesting_level and nesting_level > 1:
            # nesting CTEs are consumed once rendered at their level
            for cte in list(ctes.keys()):
                cte_level, cte_name = self.level_name_by_cte[
                    cte._get_reference_cte()
                ]
                del self.ctes[cte]
                del self.ctes_by_level_name[(cte_level, cte_name)]
                del self.level_name_by_cte[cte._get_reference_cte()]

        return cte_text
3834
3835 def get_cte_preamble(self, recursive):
3836 if recursive:
3837 return "WITH RECURSIVE"
3838 else:
3839 return "WITH"
3840
3841 def get_select_precolumns(self, select, **kw):
3842 """Called when building a ``SELECT`` statement, position is just
3843 before column list.
3844
3845 """
3846 if select._distinct_on:
3847 util.warn_deprecated(
3848 "DISTINCT ON is currently supported only by the PostgreSQL "
3849 "dialect. Use of DISTINCT ON for other backends is currently "
3850 "silently ignored, however this usage is deprecated, and will "
3851 "raise CompileError in a future release for all backends "
3852 "that do not support this syntax.",
3853 version="1.4",
3854 )
3855 return "DISTINCT " if select._distinct else ""
3856
3857 def group_by_clause(self, select, **kw):
3858 """allow dialects to customize how GROUP BY is rendered."""
3859
3860 group_by = self._generate_delimited_list(
3861 select._group_by_clauses, OPERATORS[operators.comma_op], **kw
3862 )
3863 if group_by:
3864 return " GROUP BY " + group_by
3865 else:
3866 return ""
3867
3868 def order_by_clause(self, select, **kw):
3869 """allow dialects to customize how ORDER BY is rendered."""
3870
3871 order_by = self._generate_delimited_list(
3872 select._order_by_clauses, OPERATORS[operators.comma_op], **kw
3873 )
3874
3875 if order_by:
3876 return " ORDER BY " + order_by
3877 else:
3878 return ""
3879
3880 def for_update_clause(self, select, **kw):
3881 return " FOR UPDATE"
3882
3883 def returning_clause(self, stmt, returning_cols):
3884 raise exc.CompileError(
3885 "RETURNING is not supported by this "
3886 "dialect's statement compiler."
3887 )
3888
3889 def limit_clause(self, select, **kw):
3890 text = ""
3891 if select._limit_clause is not None:
3892 text += "\n LIMIT " + self.process(select._limit_clause, **kw)
3893 if select._offset_clause is not None:
3894 if select._limit_clause is None:
3895 text += "\n LIMIT -1"
3896 text += " OFFSET " + self.process(select._offset_clause, **kw)
3897 return text
3898
3899 def fetch_clause(self, select, **kw):
3900 text = ""
3901 if select._offset_clause is not None:
3902 text += "\n OFFSET %s ROWS" % self.process(
3903 select._offset_clause, **kw
3904 )
3905 if select._fetch_clause is not None:
3906 text += "\n FETCH FIRST %s%s ROWS %s" % (
3907 self.process(select._fetch_clause, **kw),
3908 " PERCENT" if select._fetch_clause_options["percent"] else "",
3909 "WITH TIES"
3910 if select._fetch_clause_options["with_ties"]
3911 else "ONLY",
3912 )
3913 return text
3914
3915 def visit_table(
3916 self,
3917 table,
3918 asfrom=False,
3919 iscrud=False,
3920 ashint=False,
3921 fromhints=None,
3922 use_schema=True,
3923 from_linter=None,
3924 **kwargs
3925 ):
3926 if from_linter:
3927 from_linter.froms[table] = table.fullname
3928
3929 if asfrom or ashint:
3930 effective_schema = self.preparer.schema_for_object(table)
3931
3932 if use_schema and effective_schema:
3933 ret = (
3934 self.preparer.quote_schema(effective_schema)
3935 + "."
3936 + self.preparer.quote(table.name)
3937 )
3938 else:
3939 ret = self.preparer.quote(table.name)
3940 if fromhints and table in fromhints:
3941 ret = self.format_from_hint_text(
3942 ret, table, fromhints[table], iscrud
3943 )
3944 return ret
3945 else:
3946 return ""
3947
3948 def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
3949 if from_linter:
3950 from_linter.edges.update(
3951 itertools.product(
3952 join.left._from_objects, join.right._from_objects
3953 )
3954 )
3955
3956 if join.full:
3957 join_type = " FULL OUTER JOIN "
3958 elif join.isouter:
3959 join_type = " LEFT OUTER JOIN "
3960 else:
3961 join_type = " JOIN "
3962 return (
3963 join.left._compiler_dispatch(
3964 self, asfrom=True, from_linter=from_linter, **kwargs
3965 )
3966 + join_type
3967 + join.right._compiler_dispatch(
3968 self, asfrom=True, from_linter=from_linter, **kwargs
3969 )
3970 + " ON "
3971 # TODO: likely need asfrom=True here?
3972 + join.onclause._compiler_dispatch(
3973 self, from_linter=from_linter, **kwargs
3974 )
3975 )
3976
3977 def _setup_crud_hints(self, stmt, table_text):
3978 dialect_hints = dict(
3979 [
3980 (table, hint_text)
3981 for (table, dialect), hint_text in stmt._hints.items()
3982 if dialect in ("*", self.dialect.name)
3983 ]
3984 )
3985 if stmt.table in dialect_hints:
3986 table_text = self.format_from_hint_text(
3987 table_text, stmt.table, dialect_hints[stmt.table], True
3988 )
3989 return dialect_hints, table_text
3990
    def visit_insert(self, insert_stmt, **kw):
        """Compile an INSERT statement into a SQL string.

        Renders ``INSERT [prefixes] INTO <table> [(columns)]`` followed by
        exactly one of: an embedded SELECT (INSERT..FROM SELECT),
        ``DEFAULT VALUES``, a multi-row ``VALUES (...), (...)`` list, or a
        single ``VALUES (...)``.  After that come any post-values clause and
        RETURNING, which goes before VALUES instead when
        ``self.returning_precedes_values`` is set.  A WITH clause for
        collected CTEs goes either in front of the embedded SELECT (when
        the dialect sets ``cte_follows_insert``) or in front of the whole
        statement.

        :raises CompileError: when there are no parameters and the dialect
         supports none of DEFAULT VALUES, the DEFAULT metavalue or empty
         inserts; or when a multi-row VALUES insert is requested on a
         dialect without ``supports_multivalues_insert``.
        """

        compile_state = insert_stmt._compile_state_factory(
            insert_stmt, self, **kw
        )
        insert_stmt = compile_state.statement

        toplevel = not self.stack

        if toplevel:
            self.isinsert = True
            if not self.dml_compile_state:
                self.dml_compile_state = compile_state
            if not self.compile_state:
                self.compile_state = compile_state

        self.stack.append(
            {
                "correlate_froms": set(),
                "asfrom_froms": set(),
                "selectable": insert_stmt,
            }
        )

        # crud_params holds (column, expr, value) triples; for a multi-row
        # insert it is a list of such lists, one per row
        crud_params = crud._get_crud_params(
            self, insert_stmt, compile_state, **kw
        )

        if (
            not crud_params
            and not self.dialect.supports_default_values
            and not self.dialect.supports_default_metavalue
            and not self.dialect.supports_empty_insert
        ):
            raise exc.CompileError(
                "The '%s' dialect with current database "
                "version settings does not support empty "
                "inserts." % self.dialect.name
            )

        if compile_state._has_multi_parameters:
            if not self.dialect.supports_multivalues_insert:
                raise exc.CompileError(
                    "The '%s' dialect with current database "
                    "version settings does not support "
                    "in-place multirow inserts." % self.dialect.name
                )
            # the first row's params determine the rendered column list
            crud_params_single = crud_params[0]
        else:
            crud_params_single = crud_params

        preparer = self.preparer
        supports_default_values = self.dialect.supports_default_values

        text = "INSERT "

        if insert_stmt._prefixes:
            text += self._generate_prefixes(
                insert_stmt, insert_stmt._prefixes, **kw
            )

        text += "INTO "
        table_text = preparer.format_table(insert_stmt.table)

        if insert_stmt._hints:
            _, table_text = self._setup_crud_hints(insert_stmt, table_text)

        # independent CTEs are compiled for their side effect of being
        # registered with the compiler; their text is not used here
        if insert_stmt._independent_ctes:
            for cte in insert_stmt._independent_ctes:
                cte._compiler_dispatch(self, **kw)

        text += table_text

        # the column list is rendered even when empty if the dialect has no
        # DEFAULT VALUES support
        if crud_params_single or not supports_default_values:
            text += " (%s)" % ", ".join(
                [expr for c, expr, value in crud_params_single]
            )

        if self.returning or insert_stmt._returning:
            returning_clause = self.returning_clause(
                insert_stmt, self.returning or insert_stmt._returning
            )

            if self.returning_precedes_values:
                text += " " + returning_clause
        else:
            returning_clause = None

        if insert_stmt.select is not None:
            # placed here by crud.py
            select_text = self.process(
                self.stack[-1]["insert_from_select"], insert_into=True, **kw
            )

            if self.ctes and self.dialect.cte_follows_insert:
                nesting_level = len(self.stack) if not toplevel else None
                text += " %s%s" % (
                    self._render_cte_clause(
                        nesting_level=nesting_level,
                        include_following_stack=True,
                        visiting_cte=kw.get("visiting_cte"),
                    ),
                    select_text,
                )
            else:
                text += " %s" % select_text
        elif not crud_params and supports_default_values:
            text += " DEFAULT VALUES"
        elif compile_state._has_multi_parameters:
            text += " VALUES %s" % (
                ", ".join(
                    "(%s)"
                    % (", ".join(value for c, expr, value in crud_param_set))
                    for crud_param_set in crud_params
                )
            )
        else:
            insert_single_values_expr = ", ".join(
                [value for c, expr, value in crud_params]
            )
            text += " VALUES (%s)" % insert_single_values_expr
            # only the outermost INSERT records its VALUES expression
            if toplevel:
                self.insert_single_values_expr = insert_single_values_expr

        if insert_stmt._post_values_clause is not None:
            post_values_clause = self.process(
                insert_stmt._post_values_clause, **kw
            )
            if post_values_clause:
                text += " " + post_values_clause

        if returning_clause and not self.returning_precedes_values:
            text += " " + returning_clause

        if self.ctes and not self.dialect.cte_follows_insert:
            nesting_level = len(self.stack) if not toplevel else None
            text = (
                self._render_cte_clause(
                    nesting_level=nesting_level,
                    include_following_stack=True,
                    visiting_cte=kw.get("visiting_cte"),
                )
                + text
            )

        self.stack.pop(-1)

        return text
4139
4140 def update_limit_clause(self, update_stmt):
4141 """Provide a hook for MySQL to add LIMIT to the UPDATE"""
4142 return None
4143
4144 def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
4145 """Provide a hook to override the initial table clause
4146 in an UPDATE statement.
4147
4148 MySQL overrides this.
4149
4150 """
4151 kw["asfrom"] = True
4152 return from_table._compiler_dispatch(self, iscrud=True, **kw)
4153
4154 def update_from_clause(
4155 self, update_stmt, from_table, extra_froms, from_hints, **kw
4156 ):
4157 """Provide a hook to override the generation of an
4158 UPDATE..FROM clause.
4159
4160 MySQL and MSSQL override this.
4161
4162 """
4163 raise NotImplementedError(
4164 "This backend does not support multiple-table "
4165 "criteria within UPDATE"
4166 )
4167
    def visit_update(self, update_stmt, **kw):
        """Compile an UPDATE statement into a SQL string.

        Renders ``UPDATE [prefixes] <table clause> SET col=value, ...``,
        then optional parts: RETURNING (placed here when
        ``self.returning_precedes_values`` is set), the dialect's
        UPDATE..FROM clause for extra FROM objects, WHERE, the dialect's
        LIMIT hook, RETURNING (otherwise placed here), and finally a leading
        WITH clause for collected CTEs.
        """
        compile_state = update_stmt._compile_state_factory(
            update_stmt, self, **kw
        )
        update_stmt = compile_state.statement

        toplevel = not self.stack
        if toplevel:
            self.isupdate = True
            if not self.dml_compile_state:
                self.dml_compile_state = compile_state
            if not self.compile_state:
                self.compile_state = compile_state

        extra_froms = compile_state._extra_froms
        is_multitable = bool(extra_froms)

        if is_multitable:
            # main table might be a JOIN
            # only extra FROMs that are not already part of the target are
            # handed to update_from_clause
            main_froms = set(selectable._from_objects(update_stmt.table))
            render_extra_froms = [
                f for f in extra_froms if f not in main_froms
            ]
            correlate_froms = main_froms.union(extra_froms)
        else:
            render_extra_froms = []
            correlate_froms = {update_stmt.table}

        self.stack.append(
            {
                "correlate_froms": correlate_froms,
                "asfrom_froms": correlate_froms,
                "selectable": update_stmt,
            }
        )

        text = "UPDATE "

        if update_stmt._prefixes:
            text += self._generate_prefixes(
                update_stmt, update_stmt._prefixes, **kw
            )

        table_text = self.update_tables_clause(
            update_stmt, update_stmt.table, render_extra_froms, **kw
        )
        # (column, expr, value) triples rendered as "expr=value" below
        crud_params = crud._get_crud_params(
            self, update_stmt, compile_state, **kw
        )

        if update_stmt._hints:
            dialect_hints, table_text = self._setup_crud_hints(
                update_stmt, table_text
            )
        else:
            dialect_hints = None

        # independent CTEs are compiled for their side effect of being
        # registered with the compiler; their text is not used here
        if update_stmt._independent_ctes:
            for cte in update_stmt._independent_ctes:
                cte._compiler_dispatch(self, **kw)

        text += table_text

        text += " SET "
        text += ", ".join(expr + "=" + value for c, expr, value in crud_params)

        if self.returning or update_stmt._returning:
            if self.returning_precedes_values:
                text += " " + self.returning_clause(
                    update_stmt, self.returning or update_stmt._returning
                )

        # NOTE(review): update_from_clause is called whenever extra_froms is
        # non-empty, even if render_extra_froms ended up empty; overriding
        # dialects presumably need to handle an empty list.
        if extra_froms:
            extra_from_text = self.update_from_clause(
                update_stmt,
                update_stmt.table,
                render_extra_froms,
                dialect_hints,
                **kw
            )
            if extra_from_text:
                text += " " + extra_from_text

        if update_stmt._where_criteria:
            t = self._generate_delimited_and_list(
                update_stmt._where_criteria, **kw
            )
            if t:
                text += " WHERE " + t

        limit_clause = self.update_limit_clause(update_stmt)
        if limit_clause:
            text += " " + limit_clause

        if (
            self.returning or update_stmt._returning
        ) and not self.returning_precedes_values:
            text += " " + self.returning_clause(
                update_stmt, self.returning or update_stmt._returning
            )

        if self.ctes:
            nesting_level = len(self.stack) if not toplevel else None
            text = (
                self._render_cte_clause(
                    nesting_level=nesting_level,
                    visiting_cte=kw.get("visiting_cte"),
                )
                + text
            )

        self.stack.pop(-1)

        return text
4282
4283 def delete_extra_from_clause(
4284 self, update_stmt, from_table, extra_froms, from_hints, **kw
4285 ):
4286 """Provide a hook to override the generation of an
4287 DELETE..FROM clause.
4288
4289 This can be used to implement DELETE..USING for example.
4290
4291 MySQL and MSSQL override this.
4292
4293 """
4294 raise NotImplementedError(
4295 "This backend does not support multiple-table "
4296 "criteria within DELETE"
4297 )
4298
4299 def delete_table_clause(self, delete_stmt, from_table, extra_froms):
4300 return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
4301
    def visit_delete(self, delete_stmt, **kw):
        """Compile a DELETE statement into a SQL string.

        Renders ``DELETE [prefixes] FROM <table>``, then optional parts:
        RETURNING (placed here when ``self.returning_precedes_values`` is
        set), the dialect's extra-FROM clause for additional FROM objects,
        WHERE, RETURNING (otherwise placed here), and finally a leading WITH
        clause for collected CTEs.
        """
        compile_state = delete_stmt._compile_state_factory(
            delete_stmt, self, **kw
        )
        delete_stmt = compile_state.statement

        toplevel = not self.stack
        if toplevel:
            self.isdelete = True
            if not self.dml_compile_state:
                self.dml_compile_state = compile_state
            if not self.compile_state:
                self.compile_state = compile_state

        extra_froms = compile_state._extra_froms

        correlate_froms = {delete_stmt.table}.union(extra_froms)
        self.stack.append(
            {
                "correlate_froms": correlate_froms,
                "asfrom_froms": correlate_froms,
                "selectable": delete_stmt,
            }
        )

        text = "DELETE "

        if delete_stmt._prefixes:
            text += self._generate_prefixes(
                delete_stmt, delete_stmt._prefixes, **kw
            )

        text += "FROM "
        table_text = self.delete_table_clause(
            delete_stmt, delete_stmt.table, extra_froms
        )

        if delete_stmt._hints:
            dialect_hints, table_text = self._setup_crud_hints(
                delete_stmt, table_text
            )
        else:
            dialect_hints = None

        # independent CTEs are compiled for their side effect of being
        # registered with the compiler; their text is not used here
        if delete_stmt._independent_ctes:
            for cte in delete_stmt._independent_ctes:
                cte._compiler_dispatch(self, **kw)

        text += table_text

        # NOTE(review): unlike visit_insert / visit_update, only the
        # statement's own _returning is consulted here, not self.returning;
        # confirm this is intended.
        if delete_stmt._returning:
            if self.returning_precedes_values:
                text += " " + self.returning_clause(
                    delete_stmt, delete_stmt._returning
                )

        if extra_froms:
            extra_from_text = self.delete_extra_from_clause(
                delete_stmt,
                delete_stmt.table,
                extra_froms,
                dialect_hints,
                **kw
            )
            if extra_from_text:
                text += " " + extra_from_text

        if delete_stmt._where_criteria:
            t = self._generate_delimited_and_list(
                delete_stmt._where_criteria, **kw
            )
            if t:
                text += " WHERE " + t

        if delete_stmt._returning and not self.returning_precedes_values:
            text += " " + self.returning_clause(
                delete_stmt, delete_stmt._returning
            )

        if self.ctes:
            nesting_level = len(self.stack) if not toplevel else None
            text = (
                self._render_cte_clause(
                    nesting_level=nesting_level,
                    visiting_cte=kw.get("visiting_cte"),
                )
                + text
            )

        self.stack.pop(-1)

        return text
4394
4395 def visit_savepoint(self, savepoint_stmt):
4396 return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
4397
4398 def visit_rollback_to_savepoint(self, savepoint_stmt):
4399 return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(
4400 savepoint_stmt
4401 )
4402
4403 def visit_release_savepoint(self, savepoint_stmt):
4404 return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(
4405 savepoint_stmt
4406 )
4407
4408
class StrSQLCompiler(SQLCompiler):
    """A :class:`.SQLCompiler` subclass which allows a small selection
    of non-standard SQL features to render into a string value.

    The :class:`.StrSQLCompiler` is invoked whenever a Core expression
    element is directly stringified without calling upon the
    :meth:`_expression.ClauseElement.compile` method.
    It can render a limited set
    of non-standard SQL constructs to assist in basic stringification,
    however for more substantial custom or dialect-specific SQL constructs,
    it will be necessary to make use of
    :meth:`_expression.ClauseElement.compile`
    directly.

    .. seealso::

        :ref:`faq_sql_expression_string`

    """

    def _fallback_column_name(self, column):
        # Placeholder text rendered in place of a column name (presumably
        # for columns that have none); stringification is best-effort.
        return "<name unknown>"

    @util.preload_module("sqlalchemy.engine.url")
    def visit_unsupported_compilation(self, element, err, **kw):
        """Try the element's preferred dialect before giving up.

        When ``element.stringify_dialect`` names something other than
        ``"default"``, that dialect is loaded and its statement compiler is
        used to render the element, unless that compiler is itself a
        :class:`.StrSQLCompiler`.  The compile keyword arguments (for
        example ``literal_binds``) are forwarded so the delegated rendering
        honors the same options as the surrounding stringification.

        Otherwise the base class ``visit_unsupported_compilation`` is
        invoked with ``element`` and ``err``.
        """
        if element.stringify_dialect != "default":
            url = util.preloaded.engine_url
            dialect = url.URL.create(element.stringify_dialect).get_dialect()()

            compiler = dialect.statement_compiler(dialect, None)
            if not isinstance(compiler, StrSQLCompiler):
                # forward **kw; previously options such as literal_binds
                # were silently dropped for the delegated compilation
                return compiler.process(element, **kw)

        return super(StrSQLCompiler, self).visit_unsupported_compilation(
            element, err
        )

    def visit_getitem_binary(self, binary, operator, **kw):
        """Render an index operation as ``left[right]``."""
        return "%s[%s]" % (
            self.process(binary.left, **kw),
            self.process(binary.right, **kw),
        )

    def visit_json_getitem_op_binary(self, binary, operator, **kw):
        """Render a JSON index operation with the generic ``[]`` form."""
        return self.visit_getitem_binary(binary, operator, **kw)

    def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
        """Render a JSON path index operation with the generic ``[]`` form."""
        return self.visit_getitem_binary(binary, operator, **kw)

    def visit_sequence(self, seq, **kw):
        """Render a descriptive placeholder for a sequence's next value."""
        return "<next sequence value: %s>" % self.preparer.format_sequence(seq)

    def returning_clause(self, stmt, returning_cols):
        """Render ``RETURNING col, col, ...`` for the given columns.

        Each column is rendered through ``_label_select_column`` with its
        ``_non_anon_label`` supplied as the fallback label name.
        """
        columns = [
            self._label_select_column(
                None, c, True, False, {}, fallback_label_name=c._non_anon_label
            )
            for c in base._select_iterables(returning_cols)
        ]

        return "RETURNING " + ", ".join(columns)

    def update_from_clause(
        self, update_stmt, from_table, extra_froms, from_hints, **kw
    ):
        """Render ``FROM a, b, ...`` listing the UPDATE's extra FROM objects.

        Each FROM object is compiled with ``asfrom=True``.
        """
        kw["asfrom"] = True
        return "FROM " + ", ".join(
            t._compiler_dispatch(self, fromhints=from_hints, **kw)
            for t in extra_froms
        )

    def delete_extra_from_clause(
        self, update_stmt, from_table, extra_froms, from_hints, **kw
    ):
        """Render ``, a, b, ...`` listing the DELETE's extra FROM objects.

        The result starts with ``", "`` so it can be appended after the
        primary table.  Each FROM object is compiled with ``asfrom=True``.
        """
        kw["asfrom"] = True
        return ", " + ", ".join(
            t._compiler_dispatch(self, fromhints=from_hints, **kw)
            for t in extra_froms
        )

    def visit_empty_set_expr(self, type_):
        """Render a SELECT that yields no rows (``WHERE 1!=1``)."""
        return "SELECT 1 WHERE 1!=1"

    def get_from_hint_text(self, table, text):
        """Render a FROM hint as ``[text]``."""
        return "[%s]" % text

    def visit_regexp_match_op_binary(self, binary, operator, **kw):
        """Render a regexp match with a descriptive ``<regexp>`` operator."""
        return self._generate_generic_binary(binary, " <regexp> ", **kw)

    def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
        """Render a negated regexp match with ``<not regexp>``."""
        return self._generate_generic_binary(binary, " <not regexp> ", **kw)

    def visit_regexp_replace_op_binary(self, binary, operator, **kw):
        """Render a regexp replace as ``<regexp replace>(left, right)``."""
        return "<regexp replace>(%s, %s)" % (
            binary.left._compiler_dispatch(self, **kw),
            binary.right._compiler_dispatch(self, **kw),
        )
4506
4507
class DDLCompiler(Compiled):
    """Compile DDL constructs into SQL strings.

    ``visit_<name>`` methods render individual DDL elements (CREATE / DROP
    of tables, views, indexes, sequences and schemas; constraints; table
    and column comments).  The ``get_*`` and ``define_*`` methods, together
    with :meth:`create_table_suffix` and :meth:`post_create_table`, render
    smaller fragments and serve as hooks.  SQL expressions that appear
    inside DDL are rendered through :attr:`sql_compiler`.
    """

    @util.memoized_property
    def sql_compiler(self):
        # Statement compiler for the same dialect, used for SQL expressions
        # embedded in DDL (defaults, CHECK text, index expressions, ...).
        # It shares this compiler's schema_translate_map and, being a
        # memoized_property, is built on first access only.
        return self.dialect.statement_compiler(
            self.dialect, None, schema_translate_map=self.schema_translate_map
        )

    @util.memoized_property
    def type_compiler(self):
        # The dialect's type compiler, cached on first access.
        return self.dialect.type_compiler

    def construct_params(
        self, params=None, extracted_parameters=None, escape_names=True
    ):
        # DDL compilation here never produces a parameter dictionary.
        return None

    def visit_ddl(self, ddl, **kwargs):
        """Render a literal DDL construct.

        ``ddl.statement`` is %-formatted against ``ddl.context``.  When the
        target is a :class:`.Table`, a copy of the context is given
        ``table``, ``schema`` and ``fullname`` entries via ``setdefault``,
        so keys already present in the context are not overridden.
        """
        # table events can substitute table and schema name
        context = ddl.context
        if isinstance(ddl.target, schema.Table):
            # copy so the defaults added below don't leak into ddl.context
            context = context.copy()

            preparer = self.preparer
            path = preparer.format_table_seq(ddl.target)
            if len(path) == 1:
                table, sch = path[0], ""
            else:
                # first element is the schema, last is the table
                table, sch = path[-1], path[0]

            context.setdefault("table", table)
            context.setdefault("schema", sch)
            context.setdefault("fullname", preparer.format_table(ddl.target))

        return self.sql_compiler.post_process_text(ddl.statement % context)

    def visit_create_schema(self, create, **kw):
        """Render ``CREATE SCHEMA <name>``."""
        schema = self.preparer.format_schema(create.element)
        return "CREATE SCHEMA " + schema

    def visit_drop_schema(self, drop, **kw):
        """Render ``DROP SCHEMA <name>``, appending CASCADE if requested."""
        schema = self.preparer.format_schema(drop.element)
        text = "DROP SCHEMA " + schema
        if drop.cascade:
            text += " CASCADE"
        return text

    def visit_create_table(self, create, **kw):
        """Render a full CREATE TABLE statement.

        Renders the CREATE keyword with any table prefixes, optional
        IF NOT EXISTS, the table name, and the dialect suffix from
        :meth:`create_table_suffix`.  The parenthesized body lists each
        rendered column on its own tab-indented line, followed by
        table-level constraints.  Finally :meth:`post_create_table` output
        is appended.

        Raises:
            CompileError: when compiling a column fails, re-raised with the
            table description and column name prefixed to the message.
        """
        table = create.element
        preparer = self.preparer

        text = "\nCREATE "
        if table._prefixes:
            text += " ".join(table._prefixes) + " "

        text += "TABLE "
        if create.if_not_exists:
            text += "IF NOT EXISTS "

        text += preparer.format_table(table) + " "

        create_table_suffix = self.create_table_suffix(table)
        if create_table_suffix:
            text += create_table_suffix + " "

        text += "("

        # the first rendered column is preceded by a bare newline; every
        # later one by ", \n"
        separator = "\n"

        # if only one primary key, specify it along with the column
        # (first_pk is passed as True only for the first primary-key
        # column encountered)
        first_pk = False
        for create_column in create.columns:
            column = create_column.element
            try:
                processed = self.process(
                    create_column, first_pk=column.primary_key and not first_pk
                )
                # None means "skip this column" (see visit_create_column)
                if processed is not None:
                    text += separator
                    separator = ", \n"
                    text += "\t" + processed
                if column.primary_key:
                    first_pk = True
            except exc.CompileError as ce:
                util.raise_(
                    exc.CompileError(
                        util.u("(in table '%s', column '%s'): %s")
                        % (table.description, column.name, ce.args[0])
                    ),
                    from_=ce,
                )

        const = self.create_table_constraints(
            table,
            _include_foreign_key_constraints=create.include_foreign_key_constraints,  # noqa
        )
        if const:
            text += separator + "\t" + const

        text += "\n)%s\n\n" % self.post_create_table(table)
        return text

    def visit_create_column(self, create, first_pk=False, **kw):
        """Render one column definition inside CREATE TABLE.

        Returns None for system columns (``column.system``), which
        :meth:`visit_create_table` then omits.  Column-level constraints are
        appended after the column specification, space separated.
        """
        column = create.element

        if column.system:
            return None

        text = self.get_column_specification(column, first_pk=first_pk)
        const = " ".join(
            self.process(constraint) for constraint in column.constraints
        )
        if const:
            text += " " + const

        return text

    def create_table_constraints(
        self, table, _include_foreign_key_constraints=None, **kw
    ):
        """Render the table-level constraints for CREATE TABLE.

        The primary key comes first, then ``table._sorted_constraints``
        other than the primary key.  When
        ``_include_foreign_key_constraints`` is given, foreign key
        constraints not in that collection are omitted.  A constraint is
        also skipped when its ``_create_rule`` returns false, or when the
        dialect supports ALTER and the constraint has ``use_alter`` set.
        Constraints that render to None are dropped.
        """

        # On some DB order is significant: visit PK first, then the
        # other constraints (engine.ReflectionTest.testbasic failed on FB2)
        constraints = []
        if table.primary_key:
            constraints.append(table.primary_key)

        all_fkcs = table.foreign_key_constraints
        if _include_foreign_key_constraints is not None:
            omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints)
        else:
            omit_fkcs = set()

        constraints.extend(
            [
                c
                for c in table._sorted_constraints
                if c is not table.primary_key and c not in omit_fkcs
            ]
        )

        return ", \n\t".join(
            p
            for p in (
                self.process(constraint)
                for constraint in constraints
                if (
                    constraint._create_rule is None
                    or constraint._create_rule(self)
                )
                and (
                    # use_alter constraints are left out of CREATE TABLE
                    # only when the dialect can add them with ALTER later
                    not self.dialect.supports_alter
                    or not getattr(constraint, "use_alter", False)
                )
            )
            if p is not None
        )

    def visit_drop_table(self, drop, **kw):
        """Render ``DROP TABLE [IF EXISTS] <name>``."""
        text = "\nDROP TABLE "
        if drop.if_exists:
            text += "IF EXISTS "
        return text + self.preparer.format_table(drop.element)

    def visit_drop_view(self, drop, **kw):
        """Render ``DROP VIEW <name>``."""
        return "\nDROP VIEW " + self.preparer.format_table(drop.element)

    def _verify_index_table(self, index):
        # Raise CompileError if the index is not attached to a table.
        if index.table is None:
            raise exc.CompileError(
                "Index '%s' is not associated " "with any table." % index.name
            )

    def visit_create_index(
        self, create, include_schema=False, include_table_schema=True, **kw
    ):
        """Render ``CREATE [UNIQUE] INDEX [IF NOT EXISTS] name ON table (...)``.

        ``include_schema`` controls whether the index name is
        schema-qualified; ``include_table_schema`` is passed as
        ``use_schema`` when formatting the table.  Index expressions are
        rendered without table qualification and with literal binds.

        Raises:
            CompileError: if the index has no table or no name.
        """
        index = create.element
        self._verify_index_table(index)
        preparer = self.preparer
        text = "CREATE "
        if index.unique:
            text += "UNIQUE "
        if index.name is None:
            raise exc.CompileError(
                "CREATE INDEX requires that the index have a name"
            )

        text += "INDEX "
        if create.if_not_exists:
            text += "IF NOT EXISTS "

        text += "%s ON %s (%s)" % (
            self._prepared_index_name(index, include_schema=include_schema),
            preparer.format_table(
                index.table, use_schema=include_table_schema
            ),
            ", ".join(
                self.sql_compiler.process(
                    expr, include_table=False, literal_binds=True
                )
                for expr in index.expressions
            ),
        )
        return text

    def visit_drop_index(self, drop, **kw):
        """Render ``DROP INDEX [IF EXISTS] <schema-qualified name>``.

        Raises:
            CompileError: if the index has no name.
        """
        index = drop.element

        if index.name is None:
            raise exc.CompileError(
                "DROP INDEX requires that the index have a name"
            )
        text = "\nDROP INDEX "
        if drop.if_exists:
            text += "IF EXISTS "

        return text + self._prepared_index_name(index, include_schema=True)

    def _prepared_index_name(self, index, include_schema=False):
        # Format the index name, prefixed with "<schema>." when
        # include_schema is set and the index's table has an effective
        # schema (as reported by preparer.schema_for_object).
        if index.table is not None:
            effective_schema = self.preparer.schema_for_object(index.table)
        else:
            effective_schema = None
        if include_schema and effective_schema:
            schema_name = self.preparer.quote_schema(effective_schema)
        else:
            schema_name = None

        index_name = self.preparer.format_index(index)

        if schema_name:
            index_name = schema_name + "." + index_name
        return index_name

    def visit_add_constraint(self, create, **kw):
        """Render ``ALTER TABLE <table> ADD <constraint>``."""
        return "ALTER TABLE %s ADD %s" % (
            self.preparer.format_table(create.element.table),
            self.process(create.element),
        )

    def visit_set_table_comment(self, create, **kw):
        """Render ``COMMENT ON TABLE <table> IS <literal>``."""
        return "COMMENT ON TABLE %s IS %s" % (
            self.preparer.format_table(create.element),
            self.sql_compiler.render_literal_value(
                create.element.comment, sqltypes.String()
            ),
        )

    def visit_drop_table_comment(self, drop, **kw):
        """Render ``COMMENT ON TABLE <table> IS NULL``."""
        return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
            drop.element
        )

    def visit_set_column_comment(self, create, **kw):
        """Render ``COMMENT ON COLUMN <schema.table.column> IS <literal>``."""
        return "COMMENT ON COLUMN %s IS %s" % (
            self.preparer.format_column(
                create.element, use_table=True, use_schema=True
            ),
            self.sql_compiler.render_literal_value(
                create.element.comment, sqltypes.String()
            ),
        )

    def visit_drop_column_comment(self, drop, **kw):
        """Render ``COMMENT ON COLUMN <table.column> IS NULL``."""
        # NOTE(review): unlike visit_set_column_comment this does not pass
        # use_schema=True, so the column is not schema-qualified here --
        # confirm whether that asymmetry is intended.
        return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column(
            drop.element, use_table=True
        )

    def get_identity_options(self, identity_options):
        """Render sequence / identity options as space-separated keywords.

        Options are emitted in a fixed order: INCREMENT BY, START WITH,
        MINVALUE, MAXVALUE, NO MINVALUE, NO MAXVALUE, CACHE, then CYCLE or
        NO CYCLE.  Numeric options are formatted with ``%d``.  Returns an
        empty string when no option is set.
        """
        text = []
        if identity_options.increment is not None:
            text.append("INCREMENT BY %d" % identity_options.increment)
        if identity_options.start is not None:
            text.append("START WITH %d" % identity_options.start)
        if identity_options.minvalue is not None:
            text.append("MINVALUE %d" % identity_options.minvalue)
        if identity_options.maxvalue is not None:
            text.append("MAXVALUE %d" % identity_options.maxvalue)
        # NOTE(review): nominvalue / nomaxvalue are tested with
        # "is not None", so an explicit False also renders the NO ... form.
        if identity_options.nominvalue is not None:
            text.append("NO MINVALUE")
        if identity_options.nomaxvalue is not None:
            text.append("NO MAXVALUE")
        if identity_options.cache is not None:
            text.append("CACHE %d" % identity_options.cache)
        if identity_options.cycle is not None:
            text.append("CYCLE" if identity_options.cycle else "NO CYCLE")
        return " ".join(text)

    def visit_create_sequence(self, create, prefix=None, **kw):
        """Render ``CREATE SEQUENCE <name>`` followed by its options.

        ``prefix``, when given, is appended directly after the name.
        """
        text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(
            create.element
        )
        if prefix:
            text += prefix
        # NOTE: this assigns the dialect's default_sequence_base onto the
        # Sequence object itself, so the value persists after compilation.
        if create.element.start is None:
            create.element.start = self.dialect.default_sequence_base
        options = self.get_identity_options(create.element)
        if options:
            text += " " + options
        return text

    def visit_drop_sequence(self, drop, **kw):
        """Render ``DROP SEQUENCE <name>``."""
        return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)

    def visit_drop_constraint(self, drop, **kw):
        """Render ``ALTER TABLE <table> DROP CONSTRAINT <name> [CASCADE]``.

        Raises:
            CompileError: if the constraint has no name, or its name
            formats to None.
        """
        constraint = drop.element
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
        else:
            formatted_name = None

        if formatted_name is None:
            raise exc.CompileError(
                "Can't emit DROP CONSTRAINT for constraint %r; "
                "it has no name" % drop.element
            )
        return "ALTER TABLE %s DROP CONSTRAINT %s%s" % (
            self.preparer.format_table(drop.element.table),
            formatted_name,
            drop.cascade and " CASCADE" or "",
        )

    def get_column_specification(self, column, **kwargs):
        """Render a column's name, type and inline options for CREATE TABLE.

        Appends DEFAULT, the computed-column clause, and the identity clause
        (only when the dialect supports identity columns).  NOT NULL is
        added for non-nullable columns unless an identity clause was
        rendered.  Extra keyword arguments (e.g. ``first_pk``) are accepted
        and ignored here.
        """
        colspec = (
            self.preparer.format_column(column)
            + " "
            + self.dialect.type_compiler.process(
                column.type, type_expression=column
            )
        )
        default = self.get_column_default_string(column)
        if default is not None:
            colspec += " DEFAULT " + default

        if column.computed is not None:
            colspec += " " + self.process(column.computed)

        if (
            column.identity is not None
            and self.dialect.supports_identity_columns
        ):
            colspec += " " + self.process(column.identity)

        if not column.nullable and (
            not column.identity or not self.dialect.supports_identity_columns
        ):
            colspec += " NOT NULL"
        return colspec

    def create_table_suffix(self, table):
        # Text inserted between the table name and "(" in CREATE TABLE;
        # empty by default.
        return ""

    def post_create_table(self, table):
        # Text appended after the closing ")" of CREATE TABLE; empty by
        # default.
        return ""

    def get_column_default_string(self, column):
        """Return the rendered server default for ``column``, or None.

        Only a :class:`.DefaultClause` server default is rendered.  A string
        argument is rendered as a string literal; anything else is compiled
        as a SQL expression with literal binds.
        """
        if isinstance(column.server_default, schema.DefaultClause):
            if isinstance(column.server_default.arg, util.string_types):
                return self.sql_compiler.render_literal_value(
                    column.server_default.arg, sqltypes.STRINGTYPE
                )
            else:
                return self.sql_compiler.process(
                    column.server_default.arg, literal_binds=True
                )
        else:
            return None

    def visit_table_or_column_check_constraint(self, constraint, **kw):
        """Dispatch to the column-level or table-level CHECK renderer."""
        if constraint.is_column_level:
            return self.visit_column_check_constraint(constraint)
        else:
            return self.visit_check_constraint(constraint)

    def visit_check_constraint(self, constraint, **kw):
        """Render ``[CONSTRAINT name] CHECK (<sqltext>)`` plus deferrability."""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "CHECK (%s)" % self.sql_compiler.process(
            constraint.sqltext, include_table=False, literal_binds=True
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_column_check_constraint(self, constraint, **kw):
        """Render a column-level CHECK; same output shape as the table form."""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "CHECK (%s)" % self.sql_compiler.process(
            constraint.sqltext, include_table=False, literal_binds=True
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_primary_key_constraint(self, constraint, **kw):
        """Render ``[CONSTRAINT name] PRIMARY KEY (cols)`` plus deferrability.

        Returns an empty string for a constraint with no columns.  For an
        implicitly generated primary key the columns are taken from
        ``columns_autoinc_first``; otherwise from ``columns``.
        """
        if len(constraint) == 0:
            return ""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "PRIMARY KEY "
        text += "(%s)" % ", ".join(
            self.preparer.quote(c.name)
            for c in (
                constraint.columns_autoinc_first
                if constraint._implicit_generated
                else constraint.columns
            )
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_foreign_key_constraint(self, constraint, **kw):
        """Render a FOREIGN KEY ... REFERENCES ... clause.

        The remote table is taken from the first element's target column.
        MATCH, ON DELETE / ON UPDATE and deferrability clauses follow.
        """
        preparer = self.preparer
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        remote_table = list(constraint.elements)[0].column.table
        text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
            ", ".join(
                preparer.quote(f.parent.name) for f in constraint.elements
            ),
            self.define_constraint_remote_table(
                constraint, remote_table, preparer
            ),
            ", ".join(
                preparer.quote(f.column.name) for f in constraint.elements
            ),
        )
        text += self.define_constraint_match(constraint)
        text += self.define_constraint_cascades(constraint)
        text += self.define_constraint_deferrability(constraint)
        return text

    def define_constraint_remote_table(self, constraint, table, preparer):
        """Format the remote table clause of a CREATE CONSTRAINT clause."""

        return preparer.format_table(table)

    def visit_unique_constraint(self, constraint, **kw):
        """Render ``[CONSTRAINT name] UNIQUE (cols)`` plus deferrability.

        Returns an empty string for a constraint with no columns.
        """
        if len(constraint) == 0:
            return ""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "UNIQUE (%s)" % (
            ", ".join(self.preparer.quote(c.name) for c in constraint)
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def define_constraint_cascades(self, constraint):
        """Render `` ON DELETE ...`` / `` ON UPDATE ...`` for a foreign key.

        Both phrases are checked with ``validate_sql_phrase`` against the
        ``FK_ON_DELETE`` / ``FK_ON_UPDATE`` patterns, which raises
        CompileError on mismatch.
        """
        text = ""
        if constraint.ondelete is not None:
            text += " ON DELETE %s" % self.preparer.validate_sql_phrase(
                constraint.ondelete, FK_ON_DELETE
            )
        if constraint.onupdate is not None:
            text += " ON UPDATE %s" % self.preparer.validate_sql_phrase(
                constraint.onupdate, FK_ON_UPDATE
            )
        return text

    def define_constraint_deferrability(self, constraint):
        """Render `` [NOT] DEFERRABLE`` and `` INITIALLY ...`` if set.

        The INITIALLY phrase is checked against ``FK_INITIALLY``.
        """
        text = ""
        if constraint.deferrable is not None:
            if constraint.deferrable:
                text += " DEFERRABLE"
            else:
                text += " NOT DEFERRABLE"
        if constraint.initially is not None:
            text += " INITIALLY %s" % self.preparer.validate_sql_phrase(
                constraint.initially, FK_INITIALLY
            )
        return text

    def define_constraint_match(self, constraint):
        """Render `` MATCH <value>`` when the constraint has ``match`` set."""
        # NOTE(review): unlike ON DELETE / ON UPDATE / INITIALLY, the MATCH
        # value is interpolated without validate_sql_phrase.
        text = ""
        if constraint.match is not None:
            text += " MATCH %s" % constraint.match
        return text

    def visit_computed_column(self, generated, **kw):
        """Render ``GENERATED ALWAYS AS (<expr>)`` plus STORED / VIRTUAL.

        STORED is emitted when ``persisted`` is True and VIRTUAL when it is
        False; with None neither keyword is rendered.
        """
        text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
            generated.sqltext, include_table=False, literal_binds=True
        )
        if generated.persisted is True:
            text += " STORED"
        elif generated.persisted is False:
            text += " VIRTUAL"
        return text

    def visit_identity_column(self, identity, **kw):
        """Render ``GENERATED ALWAYS|BY DEFAULT AS IDENTITY [(options)]``."""
        text = "GENERATED %s AS IDENTITY" % (
            "ALWAYS" if identity.always else "BY DEFAULT",
        )
        options = self.get_identity_options(identity)
        if options:
            text += " (%s)" % options
        return text
5017
5018
class GenericTypeCompiler(TypeCompiler):
    """Render SQL type names for DDL in a dialect-neutral way.

    Uppercase ``visit_<TYPE>`` methods emit the type keyword itself, with
    length / precision / scale / collation where relevant.  Lowercase
    ``visit_<generic>`` methods map a generic type onto one of the
    uppercase renderers, always calling through ``self`` so that a
    subclass override of the uppercase method is honored.
    """

    # -- numeric ----------------------------------------------------------

    def visit_FLOAT(self, type_, **kw):
        return "FLOAT"

    def visit_REAL(self, type_, **kw):
        return "REAL"

    def visit_NUMERIC(self, type_, **kw):
        # NUMERIC, NUMERIC(p) or NUMERIC(p, s); scale is only consulted
        # once a precision is present.
        precision = type_.precision
        if precision is None:
            return "NUMERIC"
        scale = type_.scale
        if scale is None:
            return "NUMERIC(%(precision)s)" % {"precision": precision}
        return "NUMERIC(%(precision)s, %(scale)s)" % {
            "precision": precision,
            "scale": scale,
        }

    def visit_DECIMAL(self, type_, **kw):
        # DECIMAL, DECIMAL(p) or DECIMAL(p, s); same shape as NUMERIC.
        precision = type_.precision
        if precision is None:
            return "DECIMAL"
        scale = type_.scale
        if scale is None:
            return "DECIMAL(%(precision)s)" % {"precision": precision}
        return "DECIMAL(%(precision)s, %(scale)s)" % {
            "precision": precision,
            "scale": scale,
        }

    def visit_INTEGER(self, type_, **kw):
        return "INTEGER"

    def visit_SMALLINT(self, type_, **kw):
        return "SMALLINT"

    def visit_BIGINT(self, type_, **kw):
        return "BIGINT"

    # -- date / time ------------------------------------------------------

    def visit_TIMESTAMP(self, type_, **kw):
        return "TIMESTAMP"

    def visit_DATETIME(self, type_, **kw):
        return "DATETIME"

    def visit_DATE(self, type_, **kw):
        return "DATE"

    def visit_TIME(self, type_, **kw):
        return "TIME"

    # -- character / large object -----------------------------------------

    def visit_CLOB(self, type_, **kw):
        return "CLOB"

    def visit_NCLOB(self, type_, **kw):
        return "NCLOB"

    def _render_string_type(self, type_, name):
        """Return *name* followed by optional ``(length)`` and collation."""
        pieces = [name]
        if type_.length:
            pieces.append("(%d)" % type_.length)
        if type_.collation:
            pieces.append(' COLLATE "%s"' % type_.collation)
        return "".join(pieces)

    def visit_CHAR(self, type_, **kw):
        return self._render_string_type(type_, "CHAR")

    def visit_NCHAR(self, type_, **kw):
        return self._render_string_type(type_, "NCHAR")

    def visit_VARCHAR(self, type_, **kw):
        return self._render_string_type(type_, "VARCHAR")

    def visit_NVARCHAR(self, type_, **kw):
        return self._render_string_type(type_, "NVARCHAR")

    def visit_TEXT(self, type_, **kw):
        return self._render_string_type(type_, "TEXT")

    # -- binary / boolean -------------------------------------------------

    def visit_BLOB(self, type_, **kw):
        return "BLOB"

    def visit_BINARY(self, type_, **kw):
        length = type_.length
        return "BINARY" + ("(%d)" % length if length else "")

    def visit_VARBINARY(self, type_, **kw):
        length = type_.length
        return "VARBINARY" + ("(%d)" % length if length else "")

    def visit_BOOLEAN(self, type_, **kw):
        return "BOOLEAN"

    # -- generic types delegate to the uppercase renderers -----------------

    def visit_large_binary(self, type_, **kw):
        return self.visit_BLOB(type_, **kw)

    def visit_boolean(self, type_, **kw):
        return self.visit_BOOLEAN(type_, **kw)

    def visit_time(self, type_, **kw):
        return self.visit_TIME(type_, **kw)

    def visit_datetime(self, type_, **kw):
        return self.visit_DATETIME(type_, **kw)

    def visit_date(self, type_, **kw):
        return self.visit_DATE(type_, **kw)

    def visit_big_integer(self, type_, **kw):
        return self.visit_BIGINT(type_, **kw)

    def visit_small_integer(self, type_, **kw):
        return self.visit_SMALLINT(type_, **kw)

    def visit_integer(self, type_, **kw):
        return self.visit_INTEGER(type_, **kw)

    def visit_real(self, type_, **kw):
        return self.visit_REAL(type_, **kw)

    def visit_float(self, type_, **kw):
        return self.visit_FLOAT(type_, **kw)

    def visit_numeric(self, type_, **kw):
        return self.visit_NUMERIC(type_, **kw)

    def visit_string(self, type_, **kw):
        return self.visit_VARCHAR(type_, **kw)

    def visit_unicode(self, type_, **kw):
        return self.visit_VARCHAR(type_, **kw)

    def visit_text(self, type_, **kw):
        return self.visit_TEXT(type_, **kw)

    def visit_unicode_text(self, type_, **kw):
        return self.visit_TEXT(type_, **kw)

    def visit_enum(self, type_, **kw):
        # generic Enum is rendered as a VARCHAR
        return self.visit_VARCHAR(type_, **kw)

    # -- special cases ----------------------------------------------------

    def visit_null(self, type_, **kw):
        # NullType has no DDL representation.
        raise exc.CompileError(
            "Can't generate DDL for %r; "
            "did you forget to specify a "
            "type on this Column?" % type_
        )

    def visit_type_decorator(self, type_, **kw):
        # render the dialect-level implementation type of the decorator
        return self.process(type_.type_engine(self.dialect), **kw)

    def visit_user_defined(self, type_, **kw):
        # user-defined types supply their own column specification
        return type_.get_col_spec(**kw)
5171
5172
class StrSQLTypeCompiler(GenericTypeCompiler):
    """Type compiler used for plain stringification.

    It never fails on an unknown type: objects without a compiler dispatch,
    and any ``visit_*`` method not defined here or on the base class, fall
    back to :meth:`_visit_unknown`.
    """

    def process(self, type_, **kw):
        try:
            dispatch = type_._compiler_dispatch
        except AttributeError:
            # not a visitable type; render it best-effort
            return self._visit_unknown(type_, **kw)
        return dispatch(self, **kw)

    def __getattr__(self, key):
        # only reached for attributes not found normally; any missing
        # visit_* method resolves to the generic fallback
        if not key.startswith("visit_"):
            raise AttributeError(key)
        return self._visit_unknown

    def _visit_unknown(self, type_, **kw):
        # all-uppercase class names (e.g. VARCHAR) read as SQL keywords;
        # anything else is shown via repr()
        class_name = type_.__class__.__name__
        if class_name != class_name.upper():
            return repr(type_)
        return class_name

    def visit_null(self, type_, **kw):
        return "NULL"

    def visit_user_defined(self, type_, **kw):
        try:
            render_spec = type_.get_col_spec
        except AttributeError:
            return repr(type_)
        return render_spec(**kw)
5204
5205
5206class IdentifierPreparer(object):
5207
5208 """Handle quoting and case-folding of identifiers based on options."""
5209
5210 reserved_words = RESERVED_WORDS
5211
5212 legal_characters = LEGAL_CHARACTERS
5213
5214 illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
5215
5216 schema_for_object = operator.attrgetter("schema")
5217 """Return the .schema attribute for an object.
5218
5219 For the default IdentifierPreparer, the schema for an object is always
5220 the value of the ".schema" attribute. if the preparer is replaced
5221 with one that has a non-empty schema_translate_map, the value of the
5222 ".schema" attribute is rendered a symbol that will be converted to a
5223 real schema name from the mapping post-compile.
5224
5225 """
5226
5227 def __init__(
5228 self,
5229 dialect,
5230 initial_quote='"',
5231 final_quote=None,
5232 escape_quote='"',
5233 quote_case_sensitive_collations=True,
5234 omit_schema=False,
5235 ):
5236 """Construct a new ``IdentifierPreparer`` object.
5237
5238 initial_quote
5239 Character that begins a delimited identifier.
5240
5241 final_quote
5242 Character that ends a delimited identifier. Defaults to
5243 `initial_quote`.
5244
5245 omit_schema
5246 Prevent prepending schema name. Useful for databases that do
5247 not support schemae.
5248 """
5249
5250 self.dialect = dialect
5251 self.initial_quote = initial_quote
5252 self.final_quote = final_quote or self.initial_quote
5253 self.escape_quote = escape_quote
5254 self.escape_to_quote = self.escape_quote * 2
5255 self.omit_schema = omit_schema
5256 self.quote_case_sensitive_collations = quote_case_sensitive_collations
5257 self._strings = {}
5258 self._double_percents = self.dialect.paramstyle in (
5259 "format",
5260 "pyformat",
5261 )
5262
5263 def _with_schema_translate(self, schema_translate_map):
5264 prep = self.__class__.__new__(self.__class__)
5265 prep.__dict__.update(self.__dict__)
5266
5267 def symbol_getter(obj):
5268 name = obj.schema
5269 if name in schema_translate_map and obj._use_schema_map:
5270 if name is not None and ("[" in name or "]" in name):
5271 raise exc.CompileError(
5272 "Square bracket characters ([]) not supported "
5273 "in schema translate name '%s'" % name
5274 )
5275 return quoted_name(
5276 "__[SCHEMA_%s]" % (name or "_none"), quote=False
5277 )
5278 else:
5279 return obj.schema
5280
5281 prep.schema_for_object = symbol_getter
5282 return prep
5283
5284 def _render_schema_translates(self, statement, schema_translate_map):
5285 d = schema_translate_map
5286 if None in d:
5287 d["_none"] = d[None]
5288
5289 def replace(m):
5290 name = m.group(2)
5291 effective_schema = d[name]
5292 if not effective_schema:
5293 effective_schema = self.dialect.default_schema_name
5294 if not effective_schema:
5295 # TODO: no coverage here
5296 raise exc.CompileError(
5297 "Dialect has no default schema name; can't "
5298 "use None as dynamic schema target."
5299 )
5300 return self.quote_schema(effective_schema)
5301
5302 return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement)
5303
5304 def _escape_identifier(self, value):
5305 """Escape an identifier.
5306
5307 Subclasses should override this to provide database-dependent
5308 escaping behavior.
5309 """
5310
5311 value = value.replace(self.escape_quote, self.escape_to_quote)
5312 if self._double_percents:
5313 value = value.replace("%", "%%")
5314 return value
5315
5316 def _unescape_identifier(self, value):
5317 """Canonicalize an escaped identifier.
5318
5319 Subclasses should override this to provide database-dependent
5320 unescaping behavior that reverses _escape_identifier.
5321 """
5322
5323 return value.replace(self.escape_to_quote, self.escape_quote)
5324
5325 def validate_sql_phrase(self, element, reg):
5326 """keyword sequence filter.
5327
5328 a filter for elements that are intended to represent keyword sequences,
5329 such as "INITIALLY", "INITIALLY DEFERRED", etc. no special characters
5330 should be present.
5331
5332 .. versionadded:: 1.3
5333
5334 """
5335
5336 if element is not None and not reg.match(element):
5337 raise exc.CompileError(
5338 "Unexpected SQL phrase: %r (matching against %r)"
5339 % (element, reg.pattern)
5340 )
5341 return element
5342
5343 def quote_identifier(self, value):
5344 """Quote an identifier.
5345
5346 Subclasses should override this to provide database-dependent
5347 quoting behavior.
5348 """
5349
5350 return (
5351 self.initial_quote
5352 + self._escape_identifier(value)
5353 + self.final_quote
5354 )
5355
5356 def _requires_quotes(self, value):
5357 """Return True if the given identifier requires quoting."""
5358 lc_value = value.lower()
5359 return (
5360 lc_value in self.reserved_words
5361 or value[0] in self.illegal_initial_characters
5362 or not self.legal_characters.match(util.text_type(value))
5363 or (lc_value != value)
5364 )
5365
5366 def _requires_quotes_illegal_chars(self, value):
5367 """Return True if the given identifier requires quoting, but
5368 not taking case convention into account."""
5369 return not self.legal_characters.match(util.text_type(value))
5370
5371 def quote_schema(self, schema, force=None):
5372 """Conditionally quote a schema name.
5373
5374
5375 The name is quoted if it is a reserved word, contains quote-necessary
5376 characters, or is an instance of :class:`.quoted_name` which includes
5377 ``quote`` set to ``True``.
5378
5379 Subclasses can override this to provide database-dependent
5380 quoting behavior for schema names.
5381
5382 :param schema: string schema name
5383 :param force: unused
5384
5385 .. deprecated:: 0.9
5386
5387 The :paramref:`.IdentifierPreparer.quote_schema.force`
5388 parameter is deprecated and will be removed in a future
5389 release. This flag has no effect on the behavior of the
5390 :meth:`.IdentifierPreparer.quote` method; please refer to
5391 :class:`.quoted_name`.
5392
5393 """
5394 if force is not None:
5395 # not using the util.deprecated_params() decorator in this
5396 # case because of the additional function call overhead on this
5397 # very performance-critical spot.
5398 util.warn_deprecated(
5399 "The IdentifierPreparer.quote_schema.force parameter is "
5400 "deprecated and will be removed in a future release. This "
5401 "flag has no effect on the behavior of the "
5402 "IdentifierPreparer.quote method; please refer to "
5403 "quoted_name().",
5404 # deprecated 0.9. warning from 1.3
5405 version="0.9",
5406 )
5407
5408 return self.quote(schema)
5409
5410 def quote(self, ident, force=None):
5411 """Conditionally quote an identifier.
5412
5413 The identifier is quoted if it is a reserved word, contains
5414 quote-necessary characters, or is an instance of
5415 :class:`.quoted_name` which includes ``quote`` set to ``True``.
5416
5417 Subclasses can override this to provide database-dependent
5418 quoting behavior for identifier names.
5419
5420 :param ident: string identifier
5421 :param force: unused
5422
5423 .. deprecated:: 0.9
5424
5425 The :paramref:`.IdentifierPreparer.quote.force`
5426 parameter is deprecated and will be removed in a future
5427 release. This flag has no effect on the behavior of the
5428 :meth:`.IdentifierPreparer.quote` method; please refer to
5429 :class:`.quoted_name`.
5430
5431 """
5432 if force is not None:
5433 # not using the util.deprecated_params() decorator in this
5434 # case because of the additional function call overhead on this
5435 # very performance-critical spot.
5436 util.warn_deprecated(
5437 "The IdentifierPreparer.quote.force parameter is "
5438 "deprecated and will be removed in a future release. This "
5439 "flag has no effect on the behavior of the "
5440 "IdentifierPreparer.quote method; please refer to "
5441 "quoted_name().",
5442 # deprecated 0.9. warning from 1.3
5443 version="0.9",
5444 )
5445
5446 force = getattr(ident, "quote", None)
5447
5448 if force is None:
5449 if ident in self._strings:
5450 return self._strings[ident]
5451 else:
5452 if self._requires_quotes(ident):
5453 self._strings[ident] = self.quote_identifier(ident)
5454 else:
5455 self._strings[ident] = ident
5456 return self._strings[ident]
5457 elif force:
5458 return self.quote_identifier(ident)
5459 else:
5460 return ident
5461
5462 def format_collation(self, collation_name):
5463 if self.quote_case_sensitive_collations:
5464 return self.quote(collation_name)
5465 else:
5466 return collation_name
5467
5468 def format_sequence(self, sequence, use_schema=True):
5469 name = self.quote(sequence.name)
5470
5471 effective_schema = self.schema_for_object(sequence)
5472
5473 if (
5474 not self.omit_schema
5475 and use_schema
5476 and effective_schema is not None
5477 ):
5478 name = self.quote_schema(effective_schema) + "." + name
5479 return name
5480
5481 def format_label(self, label, name=None):
5482 return self.quote(name or label.name)
5483
5484 def format_alias(self, alias, name=None):
5485 return self.quote(name or alias.name)
5486
5487 def format_savepoint(self, savepoint, name=None):
5488 # Running the savepoint name through quoting is unnecessary
5489 # for all known dialects. This is here to support potential
5490 # third party use cases
5491 ident = name or savepoint.ident
5492 if self._requires_quotes(ident):
5493 ident = self.quote_identifier(ident)
5494 return ident
5495
5496 @util.preload_module("sqlalchemy.sql.naming")
5497 def format_constraint(self, constraint, _alembic_quote=True):
5498 naming = util.preloaded.sql_naming
5499
5500 if constraint.name is elements._NONE_NAME:
5501 name = naming._constraint_name_for_table(
5502 constraint, constraint.table
5503 )
5504
5505 if name is None:
5506 return None
5507 else:
5508 name = constraint.name
5509
5510 if constraint.__visit_name__ == "index":
5511 return self.truncate_and_render_index_name(
5512 name, _alembic_quote=_alembic_quote
5513 )
5514 else:
5515 return self.truncate_and_render_constraint_name(
5516 name, _alembic_quote=_alembic_quote
5517 )
5518
5519 def truncate_and_render_index_name(self, name, _alembic_quote=True):
5520 # calculate these at format time so that ad-hoc changes
5521 # to dialect.max_identifier_length etc. can be reflected
5522 # as IdentifierPreparer is long lived
5523 max_ = (
5524 self.dialect.max_index_name_length
5525 or self.dialect.max_identifier_length
5526 )
5527 return self._truncate_and_render_maxlen_name(
5528 name, max_, _alembic_quote
5529 )
5530
5531 def truncate_and_render_constraint_name(self, name, _alembic_quote=True):
5532 # calculate these at format time so that ad-hoc changes
5533 # to dialect.max_identifier_length etc. can be reflected
5534 # as IdentifierPreparer is long lived
5535 max_ = (
5536 self.dialect.max_constraint_name_length
5537 or self.dialect.max_identifier_length
5538 )
5539 return self._truncate_and_render_maxlen_name(
5540 name, max_, _alembic_quote
5541 )
5542
5543 def _truncate_and_render_maxlen_name(self, name, max_, _alembic_quote):
5544 if isinstance(name, elements._truncated_label):
5545 if len(name) > max_:
5546 name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
5547 else:
5548 self.dialect.validate_identifier(name)
5549
5550 if not _alembic_quote:
5551 return name
5552 else:
5553 return self.quote(name)
5554
5555 def format_index(self, index):
5556 return self.format_constraint(index)
5557
5558 def format_table(self, table, use_schema=True, name=None):
5559 """Prepare a quoted table and schema name."""
5560
5561 if name is None:
5562 name = table.name
5563
5564 result = self.quote(name)
5565
5566 effective_schema = self.schema_for_object(table)
5567
5568 if not self.omit_schema and use_schema and effective_schema:
5569 result = self.quote_schema(effective_schema) + "." + result
5570 return result
5571
5572 def format_schema(self, name):
5573 """Prepare a quoted schema name."""
5574
5575 return self.quote(name)
5576
5577 def format_label_name(
5578 self,
5579 name,
5580 anon_map=None,
5581 ):
5582 """Prepare a quoted column name."""
5583
5584 if anon_map is not None and isinstance(
5585 name, elements._truncated_label
5586 ):
5587 name = name.apply_map(anon_map)
5588
5589 return self.quote(name)
5590
5591 def format_column(
5592 self,
5593 column,
5594 use_table=False,
5595 name=None,
5596 table_name=None,
5597 use_schema=False,
5598 anon_map=None,
5599 ):
5600 """Prepare a quoted column name."""
5601
5602 if name is None:
5603 name = column.name
5604
5605 if anon_map is not None and isinstance(
5606 name, elements._truncated_label
5607 ):
5608 name = name.apply_map(anon_map)
5609
5610 if not getattr(column, "is_literal", False):
5611 if use_table:
5612 return (
5613 self.format_table(
5614 column.table, use_schema=use_schema, name=table_name
5615 )
5616 + "."
5617 + self.quote(name)
5618 )
5619 else:
5620 return self.quote(name)
5621 else:
5622 # literal textual elements get stuck into ColumnClause a lot,
5623 # which shouldn't get quoted
5624
5625 if use_table:
5626 return (
5627 self.format_table(
5628 column.table, use_schema=use_schema, name=table_name
5629 )
5630 + "."
5631 + name
5632 )
5633 else:
5634 return name
5635
5636 def format_table_seq(self, table, use_schema=True):
5637 """Format table name and schema as a tuple."""
5638
5639 # Dialects with more levels in their fully qualified references
5640 # ('database', 'owner', etc.) could override this and return
5641 # a longer sequence.
5642
5643 effective_schema = self.schema_for_object(table)
5644
5645 if not self.omit_schema and use_schema and effective_schema:
5646 return (
5647 self.quote_schema(effective_schema),
5648 self.format_table(table, use_schema=False),
5649 )
5650 else:
5651 return (self.format_table(table, use_schema=False),)
5652
5653 @util.memoized_property
5654 def _r_identifiers(self):
5655 initial, final, escaped_final = [
5656 re.escape(s)
5657 for s in (
5658 self.initial_quote,
5659 self.final_quote,
5660 self._escape_identifier(self.final_quote),
5661 )
5662 ]
5663 r = re.compile(
5664 r"(?:"
5665 r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s"
5666 r"|([^\.]+))(?=\.|$))+"
5667 % {"initial": initial, "final": final, "escaped": escaped_final}
5668 )
5669 return r
5670
5671 def unformat_identifiers(self, identifiers):
5672 """Unpack 'schema.table.column'-like strings into components."""
5673
5674 r = self._r_identifiers
5675 return [
5676 self._unescape_identifier(i)
5677 for i in [a or b for a, b in r.findall(identifiers)]
5678 ]