1# sql/compiler.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Base SQL and DDL compiler implementations.
10
11Classes provided include:
12
13:class:`.compiler.SQLCompiler` - renders SQL
14strings
15
16:class:`.compiler.DDLCompiler` - renders DDL
17(data definition language) strings
18
19:class:`.compiler.GenericTypeCompiler` - renders
20type specification strings.
21
22To generate user-defined SQL strings, see
23:doc:`/ext/compiler`.
24
25"""
26from __future__ import annotations
27
28import collections
29import collections.abc as collections_abc
30import contextlib
31from enum import IntEnum
32import functools
33import itertools
34import operator
35import re
36from time import perf_counter
37import typing
38from typing import Any
39from typing import Callable
40from typing import cast
41from typing import ClassVar
42from typing import Dict
43from typing import FrozenSet
44from typing import Iterable
45from typing import Iterator
46from typing import List
47from typing import Literal
48from typing import Mapping
49from typing import MutableMapping
50from typing import NamedTuple
51from typing import NoReturn
52from typing import Optional
53from typing import Pattern
54from typing import Protocol
55from typing import Sequence
56from typing import Set
57from typing import Tuple
58from typing import Type
59from typing import TYPE_CHECKING
60from typing import TypedDict
61from typing import Union
62
63from . import base
64from . import coercions
65from . import crud
66from . import elements
67from . import functions
68from . import operators
69from . import roles
70from . import schema
71from . import selectable
72from . import sqltypes
73from . import util as sql_util
74from ._typing import is_column_element
75from ._typing import is_dml
76from .base import _de_clone
77from .base import _from_objects
78from .base import _NONE_NAME
79from .base import _SentinelDefaultCharacterization
80from .base import NO_ARG
81from .elements import quoted_name
82from .sqltypes import TupleType
83from .visitors import prefix_anon_map
84from .. import exc
85from .. import util
86from ..util import FastIntFlag
87from ..util.typing import Self
88from ..util.typing import TupleAny
89from ..util.typing import Unpack
90
91if typing.TYPE_CHECKING:
92 from .annotation import _AnnotationDict
93 from .base import _AmbiguousTableNameMap
94 from .base import CompileState
95 from .base import Executable
96 from .cache_key import CacheKey
97 from .ddl import ExecutableDDLElement
98 from .dml import Delete
99 from .dml import Insert
100 from .dml import Update
101 from .dml import UpdateBase
102 from .dml import UpdateDMLState
103 from .dml import ValuesBase
104 from .elements import _truncated_label
105 from .elements import BinaryExpression
106 from .elements import BindParameter
107 from .elements import ClauseElement
108 from .elements import ColumnClause
109 from .elements import ColumnElement
110 from .elements import False_
111 from .elements import Label
112 from .elements import Null
113 from .elements import True_
114 from .functions import Function
115 from .schema import Column
116 from .schema import Constraint
117 from .schema import ForeignKeyConstraint
118 from .schema import Index
119 from .schema import PrimaryKeyConstraint
120 from .schema import Table
121 from .schema import UniqueConstraint
122 from .selectable import _ColumnsClauseElement
123 from .selectable import AliasedReturnsRows
124 from .selectable import CompoundSelectState
125 from .selectable import CTE
126 from .selectable import FromClause
127 from .selectable import NamedFromClause
128 from .selectable import ReturnsRows
129 from .selectable import Select
130 from .selectable import SelectState
131 from .type_api import _BindProcessorType
132 from .type_api import TypeDecorator
133 from .type_api import TypeEngine
134 from .type_api import UserDefinedType
135 from .visitors import Visitable
136 from ..engine.cursor import CursorResultMetaData
137 from ..engine.interfaces import _CoreSingleExecuteParams
138 from ..engine.interfaces import _DBAPIAnyExecuteParams
139 from ..engine.interfaces import _DBAPIMultiExecuteParams
140 from ..engine.interfaces import _DBAPISingleExecuteParams
141 from ..engine.interfaces import _ExecuteOptions
142 from ..engine.interfaces import _GenericSetInputSizesType
143 from ..engine.interfaces import _MutableCoreSingleExecuteParams
144 from ..engine.interfaces import Dialect
145 from ..engine.interfaces import SchemaTranslateMapType
146
147
# type alias for the "from hints" collections: FROM clause -> hint string
_FromHintsType = Dict["FromClause", str]
149
# default set of lower-case words treated as reserved for identifier
# quoting purposes; presumably consulted by the IdentifierPreparer when
# deciding whether an identifier requires quoting — confirm at use site
RESERVED_WORDS = {
    "all",
    "analyse",
    "analyze",
    "and",
    "any",
    "array",
    "as",
    "asc",
    "asymmetric",
    "authorization",
    "between",
    "binary",
    "both",
    "case",
    "cast",
    "check",
    "collate",
    "column",
    "constraint",
    "create",
    "cross",
    "current_date",
    "current_role",
    "current_time",
    "current_timestamp",
    "current_user",
    "default",
    "deferrable",
    "desc",
    "distinct",
    "do",
    "else",
    "end",
    "except",
    "false",
    "for",
    "foreign",
    "freeze",
    "from",
    "full",
    "grant",
    "group",
    "having",
    "ilike",
    "in",
    "initially",
    "inner",
    "intersect",
    "into",
    "is",
    "isnull",
    "join",
    "leading",
    "left",
    "like",
    "limit",
    "localtime",
    "localtimestamp",
    "natural",
    "new",
    "not",
    "notnull",
    "null",
    "off",
    "offset",
    "old",
    "on",
    "only",
    "or",
    "order",
    "outer",
    "overlaps",
    "placing",
    "primary",
    "references",
    "right",
    "select",
    "session_user",
    "set",
    "similar",
    "some",
    "symmetric",
    "table",
    "then",
    "to",
    "trailing",
    "true",
    "union",
    "unique",
    "user",
    "using",
    "verbose",
    "when",
    "where",
}
246
# identifiers consisting entirely of these characters need no quoting
LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I)
# as above, but additionally permitting embedded spaces
LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I)
# characters which may not appear as the first character of an
# identifier: the digits 0-9 plus "$"
ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"])

# legal values for ForeignKey ON DELETE / ON UPDATE / INITIALLY keywords
# (case-insensitive, full-string matches)
FK_ON_DELETE = re.compile(
    r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I
)
FK_ON_UPDATE = re.compile(
    r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I
)
FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I)
# locates ":name" style bind parameter tokens in textual SQL; a token
# preceded by a backslash (\x5c), a colon, or a word character is not
# treated as a bind parameter
BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
# locates backslash-escaped tokens such as "\:name", presumably so the
# escape character can be stripped from the rendered statement
BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)
260
# template for the "pyformat" paramstyle, rendering e.g. "%(name)s"
_pyformat_template = "%%(%(name)s)s"
# per-paramstyle templates used to render a bound parameter; the
# "[_POSITION]" token is presumably substituted later in compilation
# with the parameter's ordinal position — confirm at use site
BIND_TEMPLATES = {
    "pyformat": _pyformat_template,
    "qmark": "?",
    "format": "%%s",
    "numeric": ":[_POSITION]",
    "numeric_dollar": "$[_POSITION]",
    "named": ":%(name)s",
}
270
271
# default string renderings for operator functions from sql.operators;
# binary operators carry surrounding spaces, unary operators a trailing
# space, and modifiers a leading space
OPERATORS = {
    # binary
    operators.and_: " AND ",
    operators.or_: " OR ",
    operators.add: " + ",
    operators.mul: " * ",
    operators.sub: " - ",
    operators.mod: " % ",
    operators.neg: "-",
    operators.lt: " < ",
    operators.le: " <= ",
    operators.ne: " != ",
    operators.gt: " > ",
    operators.ge: " >= ",
    operators.eq: " = ",
    operators.is_distinct_from: " IS DISTINCT FROM ",
    operators.is_not_distinct_from: " IS NOT DISTINCT FROM ",
    operators.concat_op: " || ",
    operators.match_op: " MATCH ",
    operators.not_match_op: " NOT MATCH ",
    operators.in_op: " IN ",
    operators.not_in_op: " NOT IN ",
    operators.comma_op: ", ",
    operators.from_: " FROM ",
    operators.as_: " AS ",
    operators.is_: " IS ",
    operators.is_not: " IS NOT ",
    operators.collate: " COLLATE ",
    # unary
    operators.exists: "EXISTS ",
    operators.distinct_op: "DISTINCT ",
    operators.inv: "NOT ",
    operators.any_op: "ANY ",
    operators.all_op: "ALL ",
    # modifiers
    operators.desc_op: " DESC",
    operators.asc_op: " ASC",
    operators.nulls_first_op: " NULLS FIRST",
    operators.nulls_last_op: " NULLS LAST",
    # bitwise
    operators.bitwise_xor_op: " ^ ",
    operators.bitwise_or_op: " | ",
    operators.bitwise_and_op: " & ",
    operators.bitwise_not_op: "~",
    operators.bitwise_lshift_op: " << ",
    operators.bitwise_rshift_op: " >> ",
}
319
# default rendered names for the known "generic function" classes
FUNCTIONS: Dict[Type[Function[Any]], str] = {
    functions.coalesce: "coalesce",
    functions.current_date: "CURRENT_DATE",
    functions.current_time: "CURRENT_TIME",
    functions.current_timestamp: "CURRENT_TIMESTAMP",
    functions.current_user: "CURRENT_USER",
    functions.localtime: "LOCALTIME",
    functions.localtimestamp: "LOCALTIMESTAMP",
    functions.random: "random",
    functions.sysdate: "sysdate",
    functions.session_user: "SESSION_USER",
    functions.user: "USER",
    functions.cube: "CUBE",
    functions.rollup: "ROLLUP",
    functions.grouping_sets: "GROUPING SETS",
}
336
337
# field names accepted by EXTRACT mapped to their rendered form; an
# identity mapping here, exposed as SQLCompiler.extract_map so that it
# can be substituted per-dialect
EXTRACT_MAP = {
    "month": "month",
    "day": "day",
    "year": "year",
    "second": "second",
    "hour": "hour",
    "doy": "doy",
    "minute": "minute",
    "quarter": "quarter",
    "dow": "dow",
    "week": "week",
    "epoch": "epoch",
    "milliseconds": "milliseconds",
    "microseconds": "microseconds",
    "timezone_hour": "timezone_hour",
    "timezone_minute": "timezone_minute",
}
355
# SQL keyword rendered for each compound SELECT operation
COMPOUND_KEYWORDS = {
    selectable._CompoundSelectKeyword.UNION: "UNION",
    selectable._CompoundSelectKeyword.UNION_ALL: "UNION ALL",
    selectable._CompoundSelectKeyword.EXCEPT: "EXCEPT",
    selectable._CompoundSelectKeyword.EXCEPT_ALL: "EXCEPT ALL",
    selectable._CompoundSelectKeyword.INTERSECT: "INTERSECT",
    selectable._CompoundSelectKeyword.INTERSECT_ALL: "INTERSECT ALL",
}
364
365
class ResultColumnsEntry(NamedTuple):
    """Tracks a column expression that is expected to be represented
    in the result rows for this statement.

    This normally refers to the columns clause of a SELECT statement
    but may also refer to a RETURNING clause, as well as for dialect-specific
    emulations.

    The module-level constants ``RM_RENDERED_NAME``, ``RM_NAME``,
    ``RM_OBJECTS`` and ``RM_TYPE`` give the positional index of each
    field, for performance-sensitive integer access.

    """

    keyname: str
    """string name that's expected in cursor.description"""

    name: str
    """column name, may be labeled"""

    objects: Tuple[Any, ...]
    """sequence of objects that should be able to locate this column
    in a RowMapping. This is typically string names and aliases
    as well as Column objects.

    """

    type: TypeEngine[Any]
    """Datatype to be associated with this column. This is where
    the "result processing" logic directly links the compiled statement
    to the rows that come back from the cursor.

    """
395
396
class _ResultMapAppender(Protocol):
    """Callable protocol for appending an entry to a compiler's result
    map; the positional arguments mirror the fields of
    :class:`.ResultColumnsEntry`.

    """

    def __call__(
        self,
        keyname: str,
        name: str,
        objects: Sequence[Any],
        type_: TypeEngine[Any],
    ) -> None: ...
405
406
# integer indexes into ResultColumnsEntry used by cursor.py.
# some profiling showed integer access faster than named tuple
RM_RENDERED_NAME: Literal[0] = 0  # -> ResultColumnsEntry.keyname
RM_NAME: Literal[1] = 1  # -> ResultColumnsEntry.name
RM_OBJECTS: Literal[2] = 2  # -> ResultColumnsEntry.objects
RM_TYPE: Literal[3] = 3  # -> ResultColumnsEntry.type
413
414
class _BaseCompilerStackEntry(TypedDict):
    """Required keys of a compiler stack entry; see
    :class:`._CompilerStackEntry` for the optional keys.

    """

    asfrom_froms: Set[FromClause]
    correlate_froms: Set[FromClause]
    selectable: ReturnsRows
419
420
class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
    """Full compiler stack entry; keys beyond those of the base class
    are optional (``total=False``).

    """

    compile_state: CompileState
    need_result_map_for_nested: bool
    need_result_map_for_compound: bool
    select_0: ReturnsRows
    insert_from_select: Select[Unpack[TupleAny]]
427
428
class ExpandedState(NamedTuple):
    """State used when producing "expanded" and "post compile" bound
    parameters for a statement.

    "expanded" parameters are generated at statement execution time to
    match the number of values actually passed; the most prominent
    example is the individual elements inside of an IN expression.

    "post compile" parameters have their SQL literal value rendered
    directly into the SQL statement at execution time, rather than being
    passed to the driver as separate parameters.

    To create an :class:`.ExpandedState` instance, use the
    :meth:`.SQLCompiler.construct_expanded_state` method on any
    :class:`.SQLCompiler` instance.

    """

    statement: str
    """String SQL statement with parameters fully expanded"""

    parameters: _CoreSingleExecuteParams
    """Parameter dictionary with parameters fully expanded.

    When the statement uses named parameters, the keys here map exactly
    to the names in the statement; for statements using positional
    parameters, the :attr:`.ExpandedState.positional_parameters`
    accessor yields the values as an ordered tuple instead.

    """

    processors: Mapping[str, _BindProcessorType[Any]]
    """mapping of bound value processors"""

    positiontup: Optional[Sequence[str]]
    """Ordered sequence of parameter names for a positional paramstyle,
    else None"""

    parameter_expansion: Mapping[str, List[str]]
    """Maps each original parameter name that was expanded to the list
    of "expanded" parameter names generated for it."""

    @property
    def positional_parameters(self) -> Tuple[Any, ...]:
        """Tuple of positional parameters, for statements that were
        compiled using a positional paramstyle.

        :raises: :class:`.InvalidRequestError` if the statement was not
         compiled positionally.

        """
        positiontup = self.positiontup
        if positiontup is None:
            raise exc.InvalidRequestError(
                "statement does not use a positional paramstyle"
            )
        parameters = self.parameters
        return tuple(parameters[name] for name in positiontup)

    @property
    def additional_parameters(self) -> _CoreSingleExecuteParams:
        """synonym for :attr:`.ExpandedState.parameters`."""
        return self.parameters
488
489
class _InsertManyValues(NamedTuple):
    """represents state to use for executing an "insertmanyvalues" statement.

    The primary consumers of this object are the
    :meth:`.SQLCompiler._deliver_insertmanyvalues_batches` and
    :meth:`.DefaultDialect._deliver_insertmanyvalues_batches` methods.

    .. versionadded:: 2.0

    """

    is_default_expr: bool
    """if True, the statement is of the form
    ``INSERT INTO TABLE DEFAULT VALUES``, and can't be rewritten as a "batch"

    """

    single_values_expr: str
    """The rendered "values" clause of the INSERT statement.

    This is typically the parenthesized section e.g. "(?, ?, ?)" or similar.
    The insertmanyvalues logic uses this string as a search and replace
    target.

    """

    insert_crud_params: List[crud._CrudParamElementStr]
    """List of Column / bind names etc. used while rewriting the statement"""

    num_positional_params_counted: int
    """the number of bound parameters in a single-row statement.

    This count may be larger or smaller than the actual number of columns
    targeted in the INSERT, as it accommodates for SQL expressions
    in the values list that may have zero or more parameters embedded
    within them.

    This count is part of what's used to organize rewritten parameter lists
    when batching.

    """

    sort_by_parameter_order: bool = False
    """if the ``sort_by_parameter_order`` parameter was used on the
    insert.

    All of the attributes following this will only be used if this is True.

    """

    includes_upsert_behaviors: bool = False
    """if True, we have to accommodate for upsert behaviors.

    This will in some cases downgrade "insertmanyvalues" that requests
    deterministic ordering.

    """

    sentinel_columns: Optional[Sequence[Column[Any]]] = None
    """List of sentinel columns that were located.

    This list is only here if the INSERT asked for
    sort_by_parameter_order=True,
    and dialect-appropriate sentinel columns were located.

    .. versionadded:: 2.0.10

    """

    num_sentinel_columns: int = 0
    """how many sentinel columns are in the above list, if any.

    This is the same as
    ``len(sentinel_columns) if sentinel_columns is not None else 0``

    """

    sentinel_param_keys: Optional[Sequence[str]] = None
    """parameter str keys in each param dictionary / tuple
    that would link to the client side "sentinel" values for that row, which
    we can use to match up parameter sets to result rows.

    This is only present if sentinel_columns is present and the INSERT
    statement actually refers to client side values for these sentinel
    columns.

    .. versionadded:: 2.0.10

    .. versionchanged:: 2.0.29 - the sequence is now string dictionary keys
       only, used against the "compiled parameters" collection before
       the parameters were converted by bound parameter processors

    """

    implicit_sentinel: bool = False
    """if True, we have exactly one sentinel column and it uses a server side
    value, currently has to generate an incrementing integer value.

    The dialect in question would have asserted that it supports receiving
    these values back and sorting on that value as a means of guaranteeing
    correlation with the incoming parameter list.

    .. versionadded:: 2.0.10

    """

    embed_values_counter: bool = False
    """Whether to embed an incrementing integer counter in each parameter
    set within the VALUES clause as parameters are batched over.

    This is only used for a specific INSERT..SELECT..VALUES..RETURNING syntax
    where a subquery is used to produce value tuples. Current support
    includes PostgreSQL, Microsoft SQL Server.

    .. versionadded:: 2.0.10

    """
607
608
class _InsertManyValuesBatch(NamedTuple):
    """represents an individual batch SQL statement for insertmanyvalues.

    This is passed through the
    :meth:`.SQLCompiler._deliver_insertmanyvalues_batches` and
    :meth:`.DefaultDialect._deliver_insertmanyvalues_batches` methods out
    to the :class:`.Connection` within the
    :meth:`.Connection._exec_insertmany_context` method.

    .. versionadded:: 2.0.10

    """

    replaced_statement: str  # SQL string rewritten for this batch
    replaced_parameters: _DBAPIAnyExecuteParams  # params for that string
    processed_setinputsizes: Optional[_GenericSetInputSizesType]
    batch: Sequence[_DBAPISingleExecuteParams]  # source per-row param sets
    sentinel_values: Sequence[Tuple[Any, ...]]  # sentinel values per row
    current_batch_size: int  # number of rows in this batch
    batchnum: int  # position of this batch (base index not visible here)
    total_batches: int  # total number of batches for the statement
    rows_sorted: bool  # True if result rows can be correlated to params
    is_downgraded: bool  # True if deterministic ordering was abandoned
632
633
class InsertmanyvaluesSentinelOpts(FastIntFlag):
    """bitflag enum indicating styles of PK defaults
    which can work as implicit sentinel columns

    """

    NOT_SUPPORTED = 1
    AUTOINCREMENT = 2
    IDENTITY = 4
    SEQUENCE = 8

    # any of the three server-generated PK styles above
    ANY_AUTOINCREMENT = AUTOINCREMENT | IDENTITY | SEQUENCE
    # mask covering all of the above flags, supported or not
    _SUPPORTED_OR_NOT = NOT_SUPPORTED | ANY_AUTOINCREMENT

    # NOTE(review): bit 32 is skipped between these two flags —
    # presumably reserved; confirm before reusing
    USE_INSERT_FROM_SELECT = 16
    RENDER_SELECT_COL_CASTS = 64
650
651
class CompilerState(IntEnum):
    """enumerates the lifecycle states of a :class:`.Compiled` object;
    assigned by ``Compiled.__init__``."""

    COMPILING = 0
    """statement is present, compilation phase in progress"""

    STRING_APPLIED = 1
    """statement is present, string form of the statement has been applied.

    Additional processors by subclasses may still be pending.

    """

    NO_STATEMENT = 2
    """compiler does not have a statement to compile, is used
    for method access"""
666
667
class Linting(IntEnum):
    """represent preferences for the 'SQL linting' feature.

    this feature currently includes support for flagging cartesian products
    in SQL statements.

    The members are also re-exported as module-level constants.

    """

    NO_LINTING = 0
    "Disable all linting."

    COLLECT_CARTESIAN_PRODUCTS = 1
    """Collect data on FROMs and cartesian products and gather into
    'self.from_linter'"""

    WARN_LINTING = 2
    "Emit warnings for linters that find problems"

    FROM_LINTING = COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING
    """Warn for cartesian products; combines COLLECT_CARTESIAN_PRODUCTS
    and WARN_LINTING"""
689
690
# re-export the Linting members as module-level constants, relying on
# the enum's definition order
NO_LINTING, COLLECT_CARTESIAN_PRODUCTS, WARN_LINTING, FROM_LINTING = tuple(
    Linting
)
694
695
class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
    """represents current state for the "cartesian product" detection
    feature.

    ``froms`` maps each FROM element to a display string; ``edges`` is a
    set of two-tuples of FROM elements that are joined to each other.

    """

    def lint(self, start=None):
        """Traverse the FROM-element graph looking for disconnected
        elements.

        Returns ``(leftover_froms, origin)`` when one or more FROM
        elements cannot be reached from ``origin`` (which is ``start``
        when given, else an arbitrary element), or ``(None, None)`` when
        the graph is fully connected or empty.

        """
        all_froms = self.froms
        if not all_froms:
            return None, None

        remaining_edges = set(self.edges)
        unvisited = set(all_froms)

        if start is None:
            origin = unvisited.pop()
        else:
            origin = start
            unvisited.remove(origin)

        queue = collections.deque([origin])

        while queue and unvisited:
            current = queue.popleft()
            unvisited.discard(current)

            # nodes are matched to edges via hash equality, since
            # "annotated" elements hash the same as their non-annotated
            # counterparts; native containment ("current in edge",
            # "edge.index(current)") avoids in-python hash() calls
            consumed = {edge for edge in remaining_edges if current in edge}

            # queue the opposite endpoint of every edge just consumed
            queue.extendleft(
                edge[not edge.index(current)] for edge in consumed
            )
            remaining_edges.difference_update(consumed)

        # anything left unvisited is disconnected from the origin: boom
        if unvisited:
            return unvisited, origin
        return None, None

    def warn(self, stmt_type="SELECT"):
        """Run :meth:`.lint` and emit a warning if disconnected FROM
        elements (i.e. a cartesian product) were found."""
        leftover, origin = self.lint()

        if leftover:
            froms_str = ", ".join(
                f'"{self.froms[from_]}"' for from_ in leftover
            )
            message = (
                "{stmt_type} statement has a cartesian product between "
                "FROM element(s) {froms} and "
                'FROM element "{start}". Apply join condition(s) '
                "between each element to resolve."
            ).format(
                stmt_type=stmt_type,
                froms=froms_str,
                start=self.froms[origin],
            )

            util.warn(message)
760
761
class Compiled:
    """Represent a compiled SQL or DDL expression.

    The ``__str__`` method of the ``Compiled`` object should produce
    the actual text of the statement. ``Compiled`` objects are
    specific to their underlying database dialect, and also may
    or may not be specific to the columns referenced within a
    particular set of bind parameters. In no case should the
    ``Compiled`` object be dependent on the actual values of those
    bind parameters, even though it may reference those values as
    defaults.
    """

    statement: Optional[ClauseElement] = None
    "The statement to compile."
    string: str = ""
    "The string representation of the ``statement``"

    state: CompilerState
    """description of the compiler's state"""

    # discriminator flags set by subclasses
    # (e.g. SQLCompiler sets is_sql=True)
    is_sql = False
    is_ddl = False

    # cache slot for result-set metadata; not populated in this class
    _cached_metadata: Optional[CursorResultMetaData] = None

    # presumably populated by SQL-rendering subclasses with
    # ResultColumnsEntry records — not assigned in this class
    _result_columns: Optional[List[ResultColumnsEntry]] = None

    schema_translate_map: Optional[SchemaTranslateMapType] = None

    execution_options: _ExecuteOptions = util.EMPTY_DICT
    """
    Execution options propagated from the statement.   In some cases,
    sub-elements of the statement can modify these.
    """

    preparer: IdentifierPreparer

    _annotations: _AnnotationDict = util.EMPTY_DICT

    compile_state: Optional[CompileState] = None
    """Optional :class:`.CompileState` object that maintains additional
    state used by the compiler.

    Major executable objects such as :class:`_expression.Insert`,
    :class:`_expression.Update`, :class:`_expression.Delete`,
    :class:`_expression.Select` will generate this
    state when compiled in order to calculate additional information about the
    object. For the top level object that is to be executed, the state can be
    stored here where it can also have applicability towards result set
    processing.

    .. versionadded:: 1.4

    """

    dml_compile_state: Optional[CompileState] = None
    """Optional :class:`.CompileState` assigned at the same point that
    .isinsert, .isupdate, or .isdelete is assigned.

    This will normally be the same object as .compile_state, with the
    exception of cases like the :class:`.ORMFromStatementCompileState`
    object.

    .. versionadded:: 1.4.40

    """

    cache_key: Optional[CacheKey] = None
    """The :class:`.CacheKey` that was generated ahead of creating this
    :class:`.Compiled` object.

    This is used for routines that need access to the original
    :class:`.CacheKey` instance generated when the :class:`.Compiled`
    instance was first cached, typically in order to reconcile
    the original list of :class:`.BindParameter` objects with a
    per-statement list that's generated on each call.

    """

    _gen_time: float
    """Generation time of this :class:`.Compiled`, used for reporting
    cache stats."""

    def __init__(
        self,
        dialect: Dialect,
        statement: Optional[ClauseElement],
        schema_translate_map: Optional[SchemaTranslateMapType] = None,
        render_schema_translate: bool = False,
        compile_kwargs: Mapping[str, Any] = util.immutabledict(),
    ):
        """Construct a new :class:`.Compiled` object.

        :param dialect: :class:`.Dialect` to compile against.

        :param statement: :class:`_expression.ClauseElement` to be compiled.

        :param schema_translate_map: dictionary of schema names to be
         translated when forming the resultant SQL

         .. seealso::

            :ref:`schema_translating`

        :param render_schema_translate: if True, the schema translations
         are rendered into the string immediately after compilation.

        :param compile_kwargs: additional kwargs that will be
         passed to the initial call to :meth:`.Compiled.process`.


        """
        self.dialect = dialect
        self.preparer = self.dialect.identifier_preparer
        if schema_translate_map:
            self.schema_translate_map = schema_translate_map
            # swap in a preparer that applies the schema translations
            self.preparer = self.preparer._with_schema_translate(
                schema_translate_map
            )

        if statement is not None:
            self.state = CompilerState.COMPILING
            self.statement = statement
            self.can_execute = statement.supports_execution
            self._annotations = statement._annotations
            if self.can_execute:
                if TYPE_CHECKING:
                    assert isinstance(statement, Executable)
                self.execution_options = statement._execution_options
            # compilation to a string happens here, within the constructor
            self.string = self.process(self.statement, **compile_kwargs)

            if render_schema_translate:
                # apply the schema translations to the rendered string
                # immediately
                assert schema_translate_map is not None
                self.string = self.preparer._render_schema_translates(
                    self.string, schema_translate_map
                )

            self.state = CompilerState.STRING_APPLIED
        else:
            self.state = CompilerState.NO_STATEMENT

        # generation timestamp, consumed for cache statistics reporting
        self._gen_time = perf_counter()

    def __init_subclass__(cls) -> None:
        # give each Compiled subclass a hook for one-time class-level setup
        cls._init_compiler_cls()
        return super().__init_subclass__()

    @classmethod
    def _init_compiler_cls(cls):
        # subclass hook invoked by __init_subclass__; no-op by default
        pass

    def _execute_on_connection(
        self, connection, distilled_params, execution_options
    ):
        # delegates to the connection when the statement is executable;
        # otherwise raises, e.g. for a bare column expression
        if self.can_execute:
            return connection._execute_compiled(
                self, distilled_params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self.statement)

    def visit_unsupported_compilation(self, element, err, **kw):
        # fallback dispatch target when no visit method exists for an
        # element; chains the original dispatch error
        raise exc.UnsupportedCompilationError(self, type(element)) from err

    @property
    def sql_compiler(self) -> SQLCompiler:
        """Return a Compiled that is capable of processing SQL expressions.

        If this compiler is one, it would likely just return 'self'.

        """

        raise NotImplementedError()

    def process(self, obj: Visitable, **kwargs: Any) -> str:
        """Render the given visitable object via compiler dispatch."""
        return obj._compiler_dispatch(self, **kwargs)

    def __str__(self) -> str:
        """Return the string text of the generated SQL or DDL."""

        # the empty string is returned until string application completes
        if self.state is CompilerState.STRING_APPLIED:
            return self.string
        else:
            return ""

    def construct_params(
        self,
        params: Optional[_CoreSingleExecuteParams] = None,
        extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
        escape_names: bool = True,
    ) -> Optional[_MutableCoreSingleExecuteParams]:
        """Return the bind params for this compiled object.

        :param params: a dict of string/object pairs whose values will
         override bind values compiled in to the
         statement.
        """

        raise NotImplementedError()

    @property
    def params(self):
        """Return the bind params for this compiled object."""
        return self.construct_params()
964
965
class TypeCompiler(util.EnsureKWArg):
    """Produces DDL specification for TypeEngine objects."""

    # enforce that every visit_ method accepts **kw
    ensure_kwarg = r"visit_\w+"

    def __init__(self, dialect: Dialect):
        self.dialect = dialect

    def process(self, type_: TypeEngine[Any], **kw: Any) -> str:
        """Render the DDL string for the given type, substituting a
        dialect-specific variant of the type when one is registered for
        this dialect."""
        variants = type_._variant_mapping
        if variants and self.dialect.name in variants:
            type_ = variants[self.dialect.name]
        return type_._compiler_dispatch(self, **kw)

    def visit_unsupported_compilation(
        self, element: Any, err: Exception, **kw: Any
    ) -> NoReturn:
        # fallback dispatch target; chains the original dispatch error
        raise exc.UnsupportedCompilationError(self, element) from err
986
987
# this was a Visitable, but to allow accurate detection of
# column elements this is actually a column element
class _CompileLabel(
    roles.BinaryElementRole[Any], elements.CompilerColumnElement
):
    """Compiler-internal, lightweight stand-in acting as an
    expression.Label."""

    __visit_name__ = "label"
    __slots__ = ("element", "name", "_alt_names")

    def __init__(self, col, name, alt_names=()):
        self.element = col
        self.name = name
        # the wrapped column itself is always usable as an alternate name
        self._alt_names = (col,) + alt_names

    def self_group(self, **kw):
        # already a single unit; grouping parenthesization never needed
        return self

    @property
    def proxy_set(self):
        # delegate straight through to the wrapped element
        return self.element.proxy_set

    @property
    def type(self):
        # the label carries the type of the element it wraps
        return self.element.type
1013
1014
class ilike_case_insensitive(
    roles.BinaryElementRole[Any], elements.CompilerColumnElement
):
    """Wrapping element marking the case-insensitive portion of an
    ILIKE construct.

    Rendering usually produces the ``lower()`` function; on PostgreSQL
    the element passes through silently, with the assumption that
    "ILIKE" is being used.

    .. versionadded:: 2.0

    """

    __visit_name__ = "ilike_case_insensitive_operand"
    __slots__ = ("element", "comparator")

    def __init__(self, element):
        self.element = element
        self.comparator = element.comparator

    def self_group(self, **kw):
        # already a single unit; grouping parenthesization never needed
        return self

    def _with_binary_element_type(self, type_):
        # re-wrap after coercing the underlying element's type
        adapted = self.element._with_binary_element_type(type_)
        return ilike_case_insensitive(adapted)

    @property
    def proxy_set(self):
        # delegate straight through to the wrapped element
        return self.element.proxy_set

    @property
    def type(self):
        # carries the type of the element it wraps
        return self.element.type
1051
1052
class SQLCompiler(Compiled):
    """Default implementation of :class:`.Compiled`.

    Compiles :class:`_expression.ClauseElement` objects into SQL strings.

    """

    extract_map = EXTRACT_MAP

    bindname_escape_characters: ClassVar[Mapping[str, str]] = (
        util.immutabledict(
            {
                "%": "P",
                "(": "A",
                ")": "Z",
                ":": "C",
                ".": "_",
                "[": "_",
                "]": "_",
                " ": "_",
            }
        )
    )
    """A mapping (e.g. dict or similar) containing a lookup of
    characters keyed to replacement characters which will be applied to all
    'bind names' used in SQL statements as a form of 'escaping'; the given
    characters are replaced entirely with the 'replacement' character when
    rendered in the SQL statement, and a similar translation is performed
    on the incoming names used in parameter dictionaries passed to methods
    like :meth:`_engine.Connection.execute`.

    This allows bound parameter names used in :func:`_sql.bindparam` and
    other constructs to have any arbitrary characters present without any
    concern for characters that aren't allowed at all on the target database.

    Third party dialects can establish their own dictionary here to replace the
    default mapping, which will ensure that the particular characters in the
    mapping will never appear in a bound parameter name.

    The dictionary is evaluated at **class creation time**, so cannot be
    modified at runtime; it must be present on the class when the class
    is first declared.

    Note that for dialects that have additional bound parameter rules such
    as additional restrictions on leading characters, the
    :meth:`_sql.SQLCompiler.bindparam_string` method may need to be augmented.
    See the cx_Oracle compiler for an example of this.

    .. versionadded:: 2.0.0rc1

    """

    _bind_translate_re: ClassVar[Pattern[str]]
    _bind_translate_chars: ClassVar[Mapping[str, str]]

    is_sql = True

    compound_keywords = COMPOUND_KEYWORDS

    isdelete: bool = False
    isinsert: bool = False
    isupdate: bool = False
    """class-level defaults which can be set at the instance
    level to define if this Compiled instance represents
    INSERT/UPDATE/DELETE
    """

    postfetch: Optional[List[Column[Any]]]
    """list of columns that can be post-fetched after INSERT or UPDATE to
    receive server-updated values"""

    insert_prefetch: Sequence[Column[Any]] = ()
    """list of columns for which default values should be evaluated before
    an INSERT takes place"""

    update_prefetch: Sequence[Column[Any]] = ()
    """list of columns for which onupdate default values should be evaluated
    before an UPDATE takes place"""

    implicit_returning: Optional[Sequence[ColumnElement[Any]]] = None
    """list of "implicit" returning columns for a toplevel INSERT or UPDATE
    statement, used to receive newly generated values of columns.

    .. versionadded:: 2.0 ``implicit_returning`` replaces the previous
       ``returning`` collection, which was not a generalized RETURNING
       collection and instead was in fact specific to the "implicit returning"
       feature.

    """

    isplaintext: bool = False

    binds: Dict[str, BindParameter[Any]]
    """a dictionary of bind parameter keys to BindParameter instances."""

    bind_names: Dict[BindParameter[Any], str]
    """a dictionary of BindParameter instances to "compiled" names
    that are actually present in the generated SQL"""

    stack: List[_CompilerStackEntry]
    """major statements such as SELECT, INSERT, UPDATE, DELETE are
    tracked in this stack using an entry format."""

    returning_precedes_values: bool = False
    """set to True classwide to generate RETURNING
    clauses before the VALUES or WHERE clause (i.e. MSSQL)
    """

    render_table_with_column_in_update_from: bool = False
    """set to True classwide to indicate the SET clause
    in a multi-table UPDATE statement should qualify
    columns with the table name (i.e. MySQL only)
    """

    ansi_bind_rules: bool = False
    """SQL 92 doesn't allow bind parameters to be used
    in the columns clause of a SELECT, nor does it allow
    ambiguous expressions like "? = ?". A compiler
    subclass can set this flag to True if the target
    driver/DB enforces this
    """

    bindtemplate: str
    """template to render bound parameters based on paramstyle."""

    compilation_bindtemplate: str
    """template used by compiler to render parameters before positional
    paramstyle application"""

    _numeric_binds_identifier_char: str
    """Character that's used as the identifier of a numerical bind param.
    For example if this char is set to ``$``, numerical binds will be rendered
    in the form ``$1, $2, $3``.
    """

    _result_columns: List[ResultColumnsEntry]
    """relates label names in the final SQL to a tuple of local
    column/label name, ColumnElement object (if any) and
    TypeEngine. CursorResult uses this for type processing and
    column targeting"""

    _textual_ordered_columns: bool = False
    """tell the result object that the column names as rendered are important,
    but they are also "ordered" vs. what is in the compiled object here.

    As of 1.4.42 this condition is only present when the statement is a
    TextualSelect, e.g. text("....").columns(...), where it is required
    that the columns are considered positionally and not by name.

    """

    _ad_hoc_textual: bool = False
    """tell the result that we encountered text() or '*' constructs in the
    middle of the result columns, but we also have compiled columns, so
    if the number of columns in cursor.description does not match how many
    expressions we have, that means we can't rely on positional at all and
    should match on name.

    """

    _ordered_columns: bool = True
    """
    if False, means we can't be sure the list of entries
    in _result_columns is actually the rendered order. Usually
    True unless using an unordered TextualSelect.
    """

    _loose_column_name_matching: bool = False
    """tell the result object that the SQL statement is textual, wants to match
    up to Column objects, and may be using the ._tq_label in the SELECT rather
    than the base name.

    """

    _numeric_binds: bool = False
    """
    True if paramstyle is "numeric". This paramstyle is trickier than
    all the others.

    """

    _render_postcompile: bool = False
    """
    whether to render out POSTCOMPILE params during the compile phase.

    This attribute is used only for end-user invocation of stmt.compile();
    it's never used for actual statement execution, where instead the
    dialect internals access and render the internal postcompile structure
    directly.

    """

    _post_compile_expanded_state: Optional[ExpandedState] = None
    """When render_postcompile is used, the ``ExpandedState`` used to create
    the "expanded" SQL is assigned here, and then used by the ``.params``
    accessor and ``.construct_params()`` methods for their return values.

    .. versionadded:: 2.0.0rc1

    """

    _pre_expanded_string: Optional[str] = None
    """Stores the original string SQL before 'post_compile' is applied,
    for cases where 'post_compile' was used.

    """

    _pre_expanded_positiontup: Optional[List[str]] = None

    _insertmanyvalues: Optional[_InsertManyValues] = None

    _insert_crud_params: Optional[crud._CrudParamSequence] = None

    literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset()
    """bindparameter objects that are rendered as literal values at statement
    execution time.

    """

    post_compile_params: FrozenSet[BindParameter[Any]] = frozenset()
    """bindparameter objects that are rendered as bound parameter placeholders
    at statement execution time.

    """

    escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT
    """Late escaping of bound parameter names that has to be converted
    to the original name when looking in the parameter dictionary.

    """

    has_out_parameters = False
    """if True, there are bindparam() objects that have the isoutparam
    flag set."""

    postfetch_lastrowid = False
    """if True, and this is an INSERT, use cursor.lastrowid to populate
    result.inserted_primary_key. """

    _cache_key_bind_match: Optional[
        Tuple[
            Dict[
                BindParameter[Any],
                List[BindParameter[Any]],
            ],
            Dict[
                str,
                BindParameter[Any],
            ],
        ]
    ] = None
    """a mapping that will relate the BindParameter object we compile
    to those that are part of the extracted collection of parameters
    in the cache key, if we were given a cache key.

    """

    positiontup: Optional[List[str]] = None
    """for a compiled construct that uses a positional paramstyle, will be
    a sequence of strings, indicating the names of bound parameters in order.

    This is used in order to render bound parameters in their correct order,
    and is combined with the :attr:`_sql.Compiled.params` dictionary to
    render parameters.

    This sequence always contains the unescaped name of the parameters.

    .. seealso::

        :ref:`faq_sql_expression_string` - includes a usage example for
        debugging use cases.

    """
    _values_bindparam: Optional[List[str]] = None

    _visited_bindparam: Optional[List[str]] = None

    inline: bool = False

    ctes: Optional[MutableMapping[CTE, str]]

    # Detect same CTE references - Dict[(level, name), cte]
    # Level is required for supporting nesting
    ctes_by_level_name: Dict[Tuple[int, str], CTE]

    # To retrieve key/level in ctes_by_level_name -
    # Dict[cte_reference, (level, cte_name, cte_opts)]
    level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]]

    ctes_recursive: bool

    # matches "__[POSTCOMPILE_<name>]" placeholders, with an optional
    # "~~<bind expression>~~" section carried along
    _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]")
    # matches pyformat placeholders, e.g. "%(name)s"
    _pyformat_pattern = re.compile(r"%\(([^)]+?)\)s")
    # matches either of the above; used when converting the intermediary
    # pyformat string to a positional paramstyle
    _positional_pattern = re.compile(
        f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}"
    )
1349
1350 @classmethod
1351 def _init_compiler_cls(cls):
1352 cls._init_bind_translate()
1353
1354 @classmethod
1355 def _init_bind_translate(cls):
1356 reg = re.escape("".join(cls.bindname_escape_characters))
1357 cls._bind_translate_re = re.compile(f"[{reg}]")
1358 cls._bind_translate_chars = cls.bindname_escape_characters
1359
    def __init__(
        self,
        dialect: Dialect,
        statement: Optional[ClauseElement],
        cache_key: Optional[CacheKey] = None,
        column_keys: Optional[Sequence[str]] = None,
        for_executemany: bool = False,
        linting: Linting = NO_LINTING,
        _supporting_against: Optional[SQLCompiler] = None,
        **kwargs: Any,
    ):
        """Construct a new :class:`.SQLCompiler` object.

        :param dialect: :class:`.Dialect` to be used

        :param statement: :class:`_expression.ClauseElement` to be compiled

        :param cache_key: the :class:`.CacheKey` generated for the statement,
         if any; used to relate compiled bound parameters back to the
         parameters extracted into the cache key.

        :param column_keys: a list of column names to be compiled into an
         INSERT or UPDATE statement.

        :param for_executemany: whether INSERT / UPDATE statements should
         expect that they are to be invoked in an "executemany" style,
         which may impact how the statement will be expected to return the
         values of defaults and autoincrement / sequences and similar.
         Depending on the backend and driver in use, support for retrieving
         these values may be disabled which means SQL expressions may
         be rendered inline, RETURNING may not be rendered, etc.

        :param linting: flags controlling compile-time linting, e.g.
         cartesian product detection.

        :param kwargs: additional keyword arguments to be consumed by the
         superclass.

        """
        self.column_keys = column_keys

        self.cache_key = cache_key

        if cache_key:
            # relate each bind param extracted into the cache key back to
            # itself; construct_params() consults this mapping when
            # extracted_parameters from an equivalent cache key are passed
            cksm = {b.key: b for b in cache_key[1]}
            ckbm = {b: [b] for b in cache_key[1]}
            self._cache_key_bind_match = (ckbm, cksm)

        # compile INSERT/UPDATE defaults/sequences to expect executemany
        # style execution, which may mean no pre-execute of defaults,
        # or no RETURNING
        self.for_executemany = for_executemany

        self.linting = linting

        # a dictionary of bind parameter keys to BindParameter
        # instances.
        self.binds = {}

        # a dictionary of BindParameter instances to "compiled" names
        # that are actually present in the generated SQL
        self.bind_names = util.column_dict()

        # stack which keeps track of nested SELECT statements
        self.stack = []

        self._result_columns = []

        # true if the paramstyle is positional
        self.positional = dialect.positional
        if self.positional:
            self._numeric_binds = nb = dialect.paramstyle.startswith("numeric")
            if nb:
                self._numeric_binds_identifier_char = (
                    "$" if dialect.paramstyle == "numeric_dollar" else ":"
                )

            # positional styles first render pyformat placeholders, which
            # are converted by _process_positional() / _process_numeric()
            self.compilation_bindtemplate = _pyformat_template
        else:
            self.compilation_bindtemplate = BIND_TEMPLATES[dialect.paramstyle]

        self.ctes = None

        self.label_length = (
            dialect.label_length or dialect.max_identifier_length
        )

        # a map which tracks "anonymous" identifiers that are created on
        # the fly here
        self.anon_map = prefix_anon_map()

        # a map which tracks "truncated" names based on
        # dialect.label_length or dialect.max_identifier_length
        self.truncated_names: Dict[Tuple[str, str], str] = {}
        self._truncated_counters: Dict[str, int] = {}

        # performs the actual compilation; sets self.string and may set
        # isinsert/isupdate/isdelete based on the statement
        Compiled.__init__(self, dialect, statement, **kwargs)

        if self.isinsert or self.isupdate or self.isdelete:
            if TYPE_CHECKING:
                assert isinstance(statement, UpdateBase)

            if self.isinsert or self.isupdate:
                if TYPE_CHECKING:
                    assert isinstance(statement, ValuesBase)
                if statement._inline:
                    self.inline = True
                elif self.for_executemany and (
                    not self.isinsert
                    or (
                        self.dialect.insert_executemany_returning
                        and statement._return_defaults
                    )
                ):
                    self.inline = True

        self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]

        if _supporting_against:
            # adopt the state of another compiler, except for paramstyle-
            # and dialect-specific attributes
            self.__dict__.update(
                {
                    k: v
                    for k, v in _supporting_against.__dict__.items()
                    if k
                    not in {
                        "state",
                        "dialect",
                        "preparer",
                        "positional",
                        "_numeric_binds",
                        "compilation_bindtemplate",
                        "bindtemplate",
                    }
                }
            )

        if self.state is CompilerState.STRING_APPLIED:
            if self.positional:
                # convert intermediary pyformat placeholders into the
                # dialect's actual positional paramstyle
                if self._numeric_binds:
                    self._process_numeric()
                else:
                    self._process_positional()

            if self._render_postcompile:
                parameters = self.construct_params(
                    escape_names=False,
                    _no_postcompile=True,
                )

                self._process_parameters_for_postcompile(
                    parameters, _populate_self=True
                )
1505
1506 @property
1507 def insert_single_values_expr(self) -> Optional[str]:
1508 """When an INSERT is compiled with a single set of parameters inside
1509 a VALUES expression, the string is assigned here, where it can be
1510 used for insert batching schemes to rewrite the VALUES expression.
1511
1512 .. versionchanged:: 2.0 This collection is no longer used by
1513 SQLAlchemy's built-in dialects, in favor of the currently
1514 internal ``_insertmanyvalues`` collection that is used only by
1515 :class:`.SQLCompiler`.
1516
1517 """
1518 if self._insertmanyvalues is None:
1519 return None
1520 else:
1521 return self._insertmanyvalues.single_values_expr
1522
1523 @util.ro_memoized_property
1524 def effective_returning(self) -> Optional[Sequence[ColumnElement[Any]]]:
1525 """The effective "returning" columns for INSERT, UPDATE or DELETE.
1526
1527 This is either the so-called "implicit returning" columns which are
1528 calculated by the compiler on the fly, or those present based on what's
1529 present in ``self.statement._returning`` (expanded into individual
1530 columns using the ``._all_selected_columns`` attribute) i.e. those set
1531 explicitly using the :meth:`.UpdateBase.returning` method.
1532
1533 .. versionadded:: 2.0
1534
1535 """
1536 if self.implicit_returning:
1537 return self.implicit_returning
1538 elif self.statement is not None and is_dml(self.statement):
1539 return [
1540 c
1541 for c in self.statement._all_selected_columns
1542 if is_column_element(c)
1543 ]
1544
1545 else:
1546 return None
1547
1548 @property
1549 def returning(self):
1550 """backwards compatibility; returns the
1551 effective_returning collection.
1552
1553 """
1554 return self.effective_returning
1555
1556 @property
1557 def current_executable(self):
1558 """Return the current 'executable' that is being compiled.
1559
1560 This is currently the :class:`_sql.Select`, :class:`_sql.Insert`,
1561 :class:`_sql.Update`, :class:`_sql.Delete`,
1562 :class:`_sql.CompoundSelect` object that is being compiled.
1563 Specifically it's assigned to the ``self.stack`` list of elements.
1564
1565 When a statement like the above is being compiled, it normally
1566 is also assigned to the ``.statement`` attribute of the
1567 :class:`_sql.Compiler` object. However, all SQL constructs are
1568 ultimately nestable, and this attribute should never be consulted
1569 by a ``visit_`` method, as it is not guaranteed to be assigned
1570 nor guaranteed to correspond to the current statement being compiled.
1571
1572 """
1573 try:
1574 return self.stack[-1]["selectable"]
1575 except IndexError as ie:
1576 raise IndexError("Compiler does not have a stack entry") from ie
1577
1578 @property
1579 def prefetch(self):
1580 return list(self.insert_prefetch) + list(self.update_prefetch)
1581
1582 @util.memoized_property
1583 def _global_attributes(self) -> Dict[Any, Any]:
1584 return {}
1585
1586 @util.memoized_instancemethod
1587 def _init_cte_state(self) -> MutableMapping[CTE, str]:
1588 """Initialize collections related to CTEs only if
1589 a CTE is located, to save on the overhead of
1590 these collections otherwise.
1591
1592 """
1593 # collect CTEs to tack on top of a SELECT
1594 # To store the query to print - Dict[cte, text_query]
1595 ctes: MutableMapping[CTE, str] = util.OrderedDict()
1596 self.ctes = ctes
1597
1598 # Detect same CTE references - Dict[(level, name), cte]
1599 # Level is required for supporting nesting
1600 self.ctes_by_level_name = {}
1601
1602 # To retrieve key/level in ctes_by_level_name -
1603 # Dict[cte_reference, (level, cte_name, cte_opts)]
1604 self.level_name_by_cte = {}
1605
1606 self.ctes_recursive = False
1607
1608 return ctes
1609
1610 @contextlib.contextmanager
1611 def _nested_result(self):
1612 """special API to support the use case of 'nested result sets'"""
1613 result_columns, ordered_columns = (
1614 self._result_columns,
1615 self._ordered_columns,
1616 )
1617 self._result_columns, self._ordered_columns = [], False
1618
1619 try:
1620 if self.stack:
1621 entry = self.stack[-1]
1622 entry["need_result_map_for_nested"] = True
1623 else:
1624 entry = None
1625 yield self._result_columns, self._ordered_columns
1626 finally:
1627 if entry:
1628 entry.pop("need_result_map_for_nested")
1629 self._result_columns, self._ordered_columns = (
1630 result_columns,
1631 ordered_columns,
1632 )
1633
    def _process_positional(self):
        """Convert the compiled string from intermediary pyformat
        placeholders to the dialect's "format" or "qmark" positional
        paramstyle, populating ``self.positiontup`` along the way.
        """
        assert not self.positiontup
        assert self.state is CompilerState.STRING_APPLIED
        assert not self._numeric_binds

        if self.dialect.paramstyle == "format":
            placeholder = "%s"
        else:
            assert self.dialect.paramstyle == "qmark"
            placeholder = "?"

        # bind names in render order; mutated by find_position below
        positions = []

        def find_position(m: re.Match[str]) -> str:
            normal_bind = m.group(1)
            if normal_bind:
                positions.append(normal_bind)
                return placeholder
            else:
                # this is a post-compile bind; record its position but
                # leave the placeholder intact for later expansion
                positions.append(m.group(2))
                return m.group(0)

        self.string = re.sub(
            self._positional_pattern, find_position, self.string
        )

        if self.escaped_bind_names:
            # positiontup always carries the *unescaped* names
            reverse_escape = {v: k for k, v in self.escaped_bind_names.items()}
            assert len(self.escaped_bind_names) == len(reverse_escape)
            self.positiontup = [
                reverse_escape.get(name, name) for name in positions
            ]
        else:
            self.positiontup = positions

        if self._insertmanyvalues:
            positions = []

            single_values_expr = re.sub(
                self._positional_pattern,
                find_position,
                self._insertmanyvalues.single_values_expr,
            )
            insert_crud_params = [
                (
                    v[0],
                    v[1],
                    re.sub(self._positional_pattern, find_position, v[2]),
                    v[3],
                )
                for v in self._insertmanyvalues.insert_crud_params
            ]

            self._insertmanyvalues = self._insertmanyvalues._replace(
                single_values_expr=single_values_expr,
                insert_crud_params=insert_crud_params,
            )
1692
    def _process_numeric(self):
        """Convert the compiled string from intermediary pyformat
        placeholders to the "numeric" / "numeric_dollar" paramstyle
        (e.g. ``:1`` or ``$1``), populating ``self.positiontup``.
        """
        assert self._numeric_binds
        assert self.state is CompilerState.STRING_APPLIED

        num = 1
        param_pos: Dict[str, str] = {}
        order: Iterable[str]
        if self._insertmanyvalues and self._values_bindparam is not None:
            # bindparams that are not in values are always placed first.
            # this avoids the need of changing them when using executemany
            # values () ()
            order = itertools.chain(
                (
                    name
                    for name in self.bind_names.values()
                    if name not in self._values_bindparam
                ),
                self.bind_names.values(),
            )
        else:
            order = self.bind_names.values()

        for bind_name in order:
            if bind_name in param_pos:
                continue
            bind = self.binds[bind_name]
            if (
                bind in self.post_compile_params
                or bind in self.literal_execute_params
            ):
                # set to None just to mark it in positiontup; it will not
                # be replaced below.
                param_pos[bind_name] = None  # type: ignore
            else:
                ph = f"{self._numeric_binds_identifier_char}{num}"
                num += 1
                param_pos[bind_name] = ph

        self.next_numeric_pos = num

        self.positiontup = list(param_pos)
        if self.escaped_bind_names:
            # positiontup keeps unescaped names; the SQL string itself
            # contains the escaped names, so translate before substituting
            len_before = len(param_pos)
            param_pos = {
                self.escaped_bind_names.get(name, name): pos
                for name, pos in param_pos.items()
            }
            assert len(param_pos) == len_before

        # Can't use format here since % chars are not escaped.
        self.string = self._pyformat_pattern.sub(
            lambda m: param_pos[m.group(1)], self.string
        )

        if self._insertmanyvalues:
            single_values_expr = (
                # format is ok here since single_values_expr includes only
                # place-holders
                self._insertmanyvalues.single_values_expr
                % param_pos
            )
            insert_crud_params = [
                (v[0], v[1], "%s", v[3])
                for v in self._insertmanyvalues.insert_crud_params
            ]

            self._insertmanyvalues = self._insertmanyvalues._replace(
                # This has the numbers (:1, :2)
                single_values_expr=single_values_expr,
                # The single binds are instead %s so they can be formatted
                insert_crud_params=insert_crud_params,
            )
1765
1766 @util.memoized_property
1767 def _bind_processors(
1768 self,
1769 ) -> MutableMapping[
1770 str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]]
1771 ]:
1772 # mypy is not able to see the two value types as the above Union,
1773 # it just sees "object". don't know how to resolve
1774 return {
1775 key: value # type: ignore
1776 for key, value in (
1777 (
1778 self.bind_names[bindparam],
1779 (
1780 bindparam.type._cached_bind_processor(self.dialect)
1781 if not bindparam.type._is_tuple_type
1782 else tuple(
1783 elem_type._cached_bind_processor(self.dialect)
1784 for elem_type in cast(
1785 TupleType, bindparam.type
1786 ).types
1787 )
1788 ),
1789 )
1790 for bindparam in self.bind_names
1791 )
1792 if value is not None
1793 }
1794
1795 def is_subquery(self):
1796 return len(self.stack) > 1
1797
1798 @property
1799 def sql_compiler(self) -> Self:
1800 return self
1801
1802 def construct_expanded_state(
1803 self,
1804 params: Optional[_CoreSingleExecuteParams] = None,
1805 escape_names: bool = True,
1806 ) -> ExpandedState:
1807 """Return a new :class:`.ExpandedState` for a given parameter set.
1808
1809 For queries that use "expanding" or other late-rendered parameters,
1810 this method will provide for both the finalized SQL string as well
1811 as the parameters that would be used for a particular parameter set.
1812
1813 .. versionadded:: 2.0.0rc1
1814
1815 """
1816 parameters = self.construct_params(
1817 params,
1818 escape_names=escape_names,
1819 _no_postcompile=True,
1820 )
1821 return self._process_parameters_for_postcompile(
1822 parameters,
1823 )
1824
    def construct_params(
        self,
        params: Optional[_CoreSingleExecuteParams] = None,
        extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
        escape_names: bool = True,
        _group_number: Optional[int] = None,
        _check: bool = True,
        _no_postcompile: bool = False,
    ) -> _MutableCoreSingleExecuteParams:
        """return a dictionary of bind parameter keys and values"""

        if self._render_postcompile and not _no_postcompile:
            # the rendered string is hard-linked to the parameter set
            # established at compile time; only that set may be returned
            assert self._post_compile_expanded_state is not None
            if not params:
                return dict(self._post_compile_expanded_state.parameters)
            else:
                raise exc.InvalidRequestError(
                    "can't construct new parameters when render_postcompile "
                    "is used; the statement is hard-linked to the original "
                    "parameters. Use construct_expanded_state to generate a "
                    "new statement and parameters."
                )

        has_escaped_names = escape_names and bool(self.escaped_bind_names)

        if extracted_parameters:
            # related the bound parameters collected in the original cache key
            # to those collected in the incoming cache key. They will not have
            # matching names but they will line up positionally in the same
            # way. The parameters present in self.bind_names may be clones of
            # these original cache key params in the case of DML but the .key
            # will be guaranteed to match.
            if self.cache_key is None:
                raise exc.CompileError(
                    "This compiled object has no original cache key; "
                    "can't pass extracted_parameters to construct_params"
                )
            else:
                orig_extracted = self.cache_key[1]

            ckbm_tuple = self._cache_key_bind_match
            assert ckbm_tuple is not None
            ckbm, _ = ckbm_tuple
            resolved_extracted = {
                bind: extracted
                for b, extracted in zip(orig_extracted, extracted_parameters)
                for bind in ckbm[b]
            }
        else:
            resolved_extracted = None

        if params:
            # a parameter dictionary was given; values from it take
            # precedence, falling back to the BindParameter's own value
            pd = {}
            for bindparam, name in self.bind_names.items():
                escaped_name = (
                    self.escaped_bind_names.get(name, name)
                    if has_escaped_names
                    else name
                )

                if bindparam.key in params:
                    pd[escaped_name] = params[bindparam.key]
                elif name in params:
                    pd[escaped_name] = params[name]

                elif _check and bindparam.required:
                    if _group_number:
                        raise exc.InvalidRequestError(
                            "A value is required for bind parameter %r, "
                            "in parameter group %d"
                            % (bindparam.key, _group_number),
                            code="cd3x",
                        )
                    else:
                        raise exc.InvalidRequestError(
                            "A value is required for bind parameter %r"
                            % bindparam.key,
                            code="cd3x",
                        )
                else:
                    # not present in the given params; use the value or
                    # callable from the BindParameter itself (or from its
                    # cache-key counterpart, if one was resolved)
                    if resolved_extracted:
                        value_param = resolved_extracted.get(
                            bindparam, bindparam
                        )
                    else:
                        value_param = bindparam

                    if bindparam.callable:
                        pd[escaped_name] = value_param.effective_value
                    else:
                        pd[escaped_name] = value_param.value
            return pd
        else:
            # no parameter dictionary; every value comes from the
            # BindParameter objects themselves
            pd = {}
            for bindparam, name in self.bind_names.items():
                escaped_name = (
                    self.escaped_bind_names.get(name, name)
                    if has_escaped_names
                    else name
                )

                if _check and bindparam.required:
                    if _group_number:
                        raise exc.InvalidRequestError(
                            "A value is required for bind parameter %r, "
                            "in parameter group %d"
                            % (bindparam.key, _group_number),
                            code="cd3x",
                        )
                    else:
                        raise exc.InvalidRequestError(
                            "A value is required for bind parameter %r"
                            % bindparam.key,
                            code="cd3x",
                        )

                if resolved_extracted:
                    value_param = resolved_extracted.get(bindparam, bindparam)
                else:
                    value_param = bindparam

                if bindparam.callable:
                    pd[escaped_name] = value_param.effective_value
                else:
                    pd[escaped_name] = value_param.value

        return pd
1952
1953 @util.memoized_instancemethod
1954 def _get_set_input_sizes_lookup(self):
1955 dialect = self.dialect
1956
1957 include_types = dialect.include_set_input_sizes
1958 exclude_types = dialect.exclude_set_input_sizes
1959
1960 dbapi = dialect.dbapi
1961
1962 def lookup_type(typ):
1963 dbtype = typ._unwrapped_dialect_impl(dialect).get_dbapi_type(dbapi)
1964
1965 if (
1966 dbtype is not None
1967 and (exclude_types is None or dbtype not in exclude_types)
1968 and (include_types is None or dbtype in include_types)
1969 ):
1970 return dbtype
1971 else:
1972 return None
1973
1974 inputsizes = {}
1975
1976 literal_execute_params = self.literal_execute_params
1977
1978 for bindparam in self.bind_names:
1979 if bindparam in literal_execute_params:
1980 continue
1981
1982 if bindparam.type._is_tuple_type:
1983 inputsizes[bindparam] = [
1984 lookup_type(typ)
1985 for typ in cast(TupleType, bindparam.type).types
1986 ]
1987 else:
1988 inputsizes[bindparam] = lookup_type(bindparam.type)
1989
1990 return inputsizes
1991
1992 @property
1993 def params(self):
1994 """Return the bind param dictionary embedded into this
1995 compiled object, for those values that are present.
1996
1997 .. seealso::
1998
1999 :ref:`faq_sql_expression_string` - includes a usage example for
2000 debugging use cases.
2001
2002 """
2003 return self.construct_params(_check=False)
2004
2005 def _process_parameters_for_postcompile(
2006 self,
2007 parameters: _MutableCoreSingleExecuteParams,
2008 _populate_self: bool = False,
2009 ) -> ExpandedState:
2010 """handle special post compile parameters.
2011
2012 These include:
2013
2014 * "expanding" parameters -typically IN tuples that are rendered
2015 on a per-parameter basis for an otherwise fixed SQL statement string.
2016
2017 * literal_binds compiled with the literal_execute flag. Used for
2018 things like SQL Server "TOP N" where the driver does not accommodate
2019 N as a bound parameter.
2020
2021 """
2022
2023 expanded_parameters = {}
2024 new_positiontup: Optional[List[str]]
2025
2026 pre_expanded_string = self._pre_expanded_string
2027 if pre_expanded_string is None:
2028 pre_expanded_string = self.string
2029
2030 if self.positional:
2031 new_positiontup = []
2032
2033 pre_expanded_positiontup = self._pre_expanded_positiontup
2034 if pre_expanded_positiontup is None:
2035 pre_expanded_positiontup = self.positiontup
2036
2037 else:
2038 new_positiontup = pre_expanded_positiontup = None
2039
2040 processors = self._bind_processors
2041 single_processors = cast(
2042 "Mapping[str, _BindProcessorType[Any]]", processors
2043 )
2044 tuple_processors = cast(
2045 "Mapping[str, Sequence[_BindProcessorType[Any]]]", processors
2046 )
2047
2048 new_processors: Dict[str, _BindProcessorType[Any]] = {}
2049
2050 replacement_expressions: Dict[str, Any] = {}
2051 to_update_sets: Dict[str, Any] = {}
2052
2053 # notes:
2054 # *unescaped* parameter names in:
2055 # self.bind_names, self.binds, self._bind_processors, self.positiontup
2056 #
2057 # *escaped* parameter names in:
2058 # construct_params(), replacement_expressions
2059
2060 numeric_positiontup: Optional[List[str]] = None
2061
2062 if self.positional and pre_expanded_positiontup is not None:
2063 names: Iterable[str] = pre_expanded_positiontup
2064 if self._numeric_binds:
2065 numeric_positiontup = []
2066 else:
2067 names = self.bind_names.values()
2068
2069 ebn = self.escaped_bind_names
2070 for name in names:
2071 escaped_name = ebn.get(name, name) if ebn else name
2072 parameter = self.binds[name]
2073
2074 if parameter in self.literal_execute_params:
2075 if escaped_name not in replacement_expressions:
2076 replacement_expressions[escaped_name] = (
2077 self.render_literal_bindparam(
2078 parameter,
2079 render_literal_value=parameters.pop(escaped_name),
2080 )
2081 )
2082 continue
2083
2084 if parameter in self.post_compile_params:
2085 if escaped_name in replacement_expressions:
2086 to_update = to_update_sets[escaped_name]
2087 values = None
2088 else:
2089 # we are removing the parameter from parameters
2090 # because it is a list value, which is not expected by
2091 # TypeEngine objects that would otherwise be asked to
2092 # process it. the single name is being replaced with
2093 # individual numbered parameters for each value in the
2094 # param.
2095 #
2096 # note we are also inserting *escaped* parameter names
2097 # into the given dictionary. default dialect will
2098 # use these param names directly as they will not be
2099 # in the escaped_bind_names dictionary.
2100 values = parameters.pop(name)
2101
2102 leep_res = self._literal_execute_expanding_parameter(
2103 escaped_name, parameter, values
2104 )
2105 (to_update, replacement_expr) = leep_res
2106
2107 to_update_sets[escaped_name] = to_update
2108 replacement_expressions[escaped_name] = replacement_expr
2109
2110 if not parameter.literal_execute:
2111 parameters.update(to_update)
2112 if parameter.type._is_tuple_type:
2113 assert values is not None
2114 new_processors.update(
2115 (
2116 "%s_%s_%s" % (name, i, j),
2117 tuple_processors[name][j - 1],
2118 )
2119 for i, tuple_element in enumerate(values, 1)
2120 for j, _ in enumerate(tuple_element, 1)
2121 if name in tuple_processors
2122 and tuple_processors[name][j - 1] is not None
2123 )
2124 else:
2125 new_processors.update(
2126 (key, single_processors[name])
2127 for key, _ in to_update
2128 if name in single_processors
2129 )
2130 if numeric_positiontup is not None:
2131 numeric_positiontup.extend(
2132 name for name, _ in to_update
2133 )
2134 elif new_positiontup is not None:
2135 # to_update has escaped names, but that's ok since
2136 # these are new names, that aren't in the
2137 # escaped_bind_names dict.
2138 new_positiontup.extend(name for name, _ in to_update)
2139 expanded_parameters[name] = [
2140 expand_key for expand_key, _ in to_update
2141 ]
2142 elif new_positiontup is not None:
2143 new_positiontup.append(name)
2144
2145 def process_expanding(m):
2146 key = m.group(1)
2147 expr = replacement_expressions[key]
2148
2149 # if POSTCOMPILE included a bind_expression, render that
2150 # around each element
2151 if m.group(2):
2152 tok = m.group(2).split("~~")
2153 be_left, be_right = tok[1], tok[3]
2154 expr = ", ".join(
2155 "%s%s%s" % (be_left, exp, be_right)
2156 for exp in expr.split(", ")
2157 )
2158 return expr
2159
2160 statement = re.sub(
2161 self._post_compile_pattern, process_expanding, pre_expanded_string
2162 )
2163
2164 if numeric_positiontup is not None:
2165 assert new_positiontup is not None
2166 param_pos = {
2167 key: f"{self._numeric_binds_identifier_char}{num}"
2168 for num, key in enumerate(
2169 numeric_positiontup, self.next_numeric_pos
2170 )
2171 }
2172 # Can't use format here since % chars are not escaped.
2173 statement = self._pyformat_pattern.sub(
2174 lambda m: param_pos[m.group(1)], statement
2175 )
2176 new_positiontup.extend(numeric_positiontup)
2177
2178 expanded_state = ExpandedState(
2179 statement,
2180 parameters,
2181 new_processors,
2182 new_positiontup,
2183 expanded_parameters,
2184 )
2185
2186 if _populate_self:
2187 # this is for the "render_postcompile" flag, which is not
2188 # otherwise used internally and is for end-user debugging and
2189 # special use cases.
2190 self._pre_expanded_string = pre_expanded_string
2191 self._pre_expanded_positiontup = pre_expanded_positiontup
2192 self.string = expanded_state.statement
2193 self.positiontup = (
2194 list(expanded_state.positiontup or ())
2195 if self.positional
2196 else None
2197 )
2198 self._post_compile_expanded_state = expanded_state
2199
2200 return expanded_state
2201
2202 @util.preload_module("sqlalchemy.engine.cursor")
2203 def _create_result_map(self):
2204 """utility method used for unit tests only."""
2205 cursor = util.preloaded.engine_cursor
2206 return cursor.CursorResultMetaData._create_description_match_map(
2207 self._result_columns
2208 )
2209
2210 # assigned by crud.py for insert/update statements
2211 _get_bind_name_for_col: _BindNameForColProtocol
2212
2213 @util.memoized_property
2214 def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
2215 getter = self._get_bind_name_for_col
2216 return getter
2217
    @util.memoized_property
    @util.preload_module("sqlalchemy.engine.result")
    def _inserted_primary_key_from_lastrowid_getter(self):
        """Return a callable ``(lastrowid, parameters) -> row`` that
        assembles the inserted primary key for an INSERT statement,
        filling the autoincrement column from ``cursor.lastrowid`` (when
        non-None) and the remaining primary key columns from the
        statement's parameters.
        """
        result = util.preloaded.engine_result

        param_key_getter = self._within_exec_param_key_getter

        assert self.compile_state is not None
        statement = self.compile_state.statement

        if TYPE_CHECKING:
            assert isinstance(statement, Insert)

        table = statement.table

        # per-primary-key-column getters that pull a value out of the
        # parameter dictionary by that column's bind name
        getters = [
            (operator.methodcaller("get", param_key_getter(col), None), col)
            for col in table.primary_key
        ]

        autoinc_getter = None
        autoinc_col = table._autoincrement_column
        if autoinc_col is not None:
            # apply type post processors to the lastrowid
            lastrowid_processor = autoinc_col.type._cached_result_processor(
                self.dialect, None
            )
            autoinc_key = param_key_getter(autoinc_col)

            # if a bind value is present for the autoincrement column
            # in the parameters, we need to do the logic dictated by
            # #7998; honor a non-None user-passed parameter over lastrowid.
            # previously in the 1.4 series we weren't fetching lastrowid
            # at all if the key were present in the parameters
            if autoinc_key in self.binds:

                def _autoinc_getter(lastrowid, parameters):
                    param_value = parameters.get(autoinc_key, lastrowid)
                    if param_value is not None:
                        # they supplied non-None parameter, use that.
                        # SQLite at least is observed to return the wrong
                        # cursor.lastrowid for INSERT..ON CONFLICT so it
                        # can't be used in all cases
                        return param_value
                    else:
                        # use lastrowid
                        return lastrowid

                # work around mypy https://github.com/python/mypy/issues/14027
                autoinc_getter = _autoinc_getter

        else:
            lastrowid_processor = None

        row_fn = result.result_tuple([col.key for col in table.primary_key])

        def get(lastrowid, parameters):
            """given cursor.lastrowid value and the parameters used for INSERT,
            return a "row" that represents the primary key, either by
            using the "lastrowid" or by extracting values from the parameters
            that were sent along with the INSERT.

            """
            if lastrowid_processor is not None:
                lastrowid = lastrowid_processor(lastrowid)

            if lastrowid is None:
                return row_fn(getter(parameters) for getter, col in getters)
            else:
                return row_fn(
                    (
                        (
                            autoinc_getter(lastrowid, parameters)
                            if autoinc_getter is not None
                            else lastrowid
                        )
                        if col is autoinc_col
                        else getter(parameters)
                    )
                    for getter, col in getters
                )

        return get
2301
    @util.memoized_property
    @util.preload_module("sqlalchemy.engine.result")
    def _inserted_primary_key_from_returning_getter(self):
        """Return a callable ``(row, parameters) -> row`` that assembles
        the inserted primary key for an INSERT using RETURNING, taking
        each primary key column from the RETURNING row when it is part
        of the implicit RETURNING list and from the statement's
        parameters otherwise.
        """
        result = util.preloaded.engine_result

        assert self.compile_state is not None
        statement = self.compile_state.statement

        if TYPE_CHECKING:
            assert isinstance(statement, Insert)

        param_key_getter = self._within_exec_param_key_getter
        table = statement.table

        returning = self.implicit_returning
        assert returning is not None
        # index of each implicit-RETURNING column within the returned row
        ret = {col: idx for idx, col in enumerate(returning)}

        # (getter, use_row) pairs: use_row selects between the RETURNING
        # row (itemgetter by index) and the parameter dictionary (get by
        # bind name)
        getters = cast(
            "List[Tuple[Callable[[Any], Any], bool]]",
            [
                (
                    (operator.itemgetter(ret[col]), True)
                    if col in ret
                    else (
                        operator.methodcaller(
                            "get", param_key_getter(col), None
                        ),
                        False,
                    )
                )
                for col in table.primary_key
            ],
        )

        row_fn = result.result_tuple([col.key for col in table.primary_key])

        def get(row, parameters):
            return row_fn(
                getter(row) if use_row else getter(parameters)
                for getter, use_row in getters
            )

        return get
2346
2347 def default_from(self) -> str:
2348 """Called when a SELECT statement has no froms, and no FROM clause is
2349 to be appended.
2350
2351 Gives Oracle Database a chance to tack on a ``FROM DUAL`` to the string
2352 output.
2353
2354 """
2355 return ""
2356
    def visit_override_binds(self, override_binds, **kw):
        """SQL compile the nested element of an _OverrideBinds with
        bindparams swapped out.

        The _OverrideBinds is not normally expected to be compiled; it
        is meant to be used when an already cached statement is to be used,
        the compilation was already performed, and only the bound params should
        be swapped in at execution time.

        However, there are test cases that exercise this object, and
        additionally the ORM subquery loader is known to feed in expressions
        which include this construct into new queries (discovered in #11173),
        so it has to do the right thing at compile time as well.

        """

        # get SQL text first
        sqltext = override_binds.element._compiler_dispatch(self, **kw)

        # for a test compile that is not for caching, change binds after the
        # fact. note that we don't try to
        # swap the bindparam as we compile, because our element may be
        # elsewhere in the statement already (e.g. a subquery or perhaps a
        # CTE) and was already visited / compiled. See
        # test_relationship_criteria.py ->
        # test_selectinload_local_criteria_subquery
        for k in override_binds.translate:
            if k not in self.binds:
                continue
            bp = self.binds[k]

            # so this would work, just change the value of bp in place.
            # but we dont want to mutate things outside.
            # bp.value = override_binds.translate[bp.key]
            # continue

            # instead, need to replace bp with new_bp or otherwise accommodate
            # in all internal collections
            new_bp = bp._with_value(
                override_binds.translate[bp.key],
                maintain_key=True,
                required=False,
            )

            # re-point all internal collections from the old bindparam
            # to the replacement
            name = self.bind_names[bp]
            self.binds[k] = self.binds[name] = new_bp
            self.bind_names[new_bp] = name
            self.bind_names.pop(bp, None)

            # keep the post-compile / literal-execute sets aware of the
            # replacement object as well
            if bp in self.post_compile_params:
                self.post_compile_params |= {new_bp}
            if bp in self.literal_execute_params:
                self.literal_execute_params |= {new_bp}

            ckbm_tuple = self._cache_key_bind_match
            if ckbm_tuple:
                ckbm, cksm = ckbm_tuple
                for bp in bp._cloned_set:
                    if bp.key in cksm:
                        cb = cksm[bp.key]
                        ckbm[cb].append(new_bp)

        return sqltext
2420
2421 def visit_grouping(self, grouping, asfrom=False, **kwargs):
2422 return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
2423
2424 def visit_select_statement_grouping(self, grouping, **kwargs):
2425 return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
2426
    def visit_label_reference(
        self, element, within_columns_clause=False, **kwargs
    ):
        """Render a reference to a Label used in ORDER BY / GROUP BY /
        DISTINCT; when the dialect supports simple order-by-label and the
        label resolves within the enclosing statement's label-resolution
        dictionary, arrange for just the label name to be rendered.
        """
        if self.stack and self.dialect.supports_simple_order_by_label:
            try:
                compile_state = cast(
                    "Union[SelectState, CompoundSelectState]",
                    self.stack[-1]["compile_state"],
                )
            except KeyError as ke:
                raise exc.CompileError(
                    "Can't resolve label reference for ORDER BY / "
                    "GROUP BY / DISTINCT etc."
                ) from ke

            (
                with_cols,
                only_froms,
                only_cols,
            ) = compile_state._label_resolve_dict
            # choose the resolution dictionary based on clause context
            if within_columns_clause:
                resolve_dict = only_froms
            else:
                resolve_dict = only_cols

            # this can be None in the case that a _label_reference()
            # were subject to a replacement operation, in which case
            # the replacement of the Label element may have changed
            # to something else like a ColumnClause expression.
            order_by_elem = element.element._order_by_label_element

            if (
                order_by_elem is not None
                and order_by_elem.name in resolve_dict
                and order_by_elem.shares_lineage(
                    resolve_dict[order_by_elem.name]
                )
            ):
                kwargs["render_label_as_label"] = (
                    element.element._order_by_label_element
                )
        return self.process(
            element.element,
            within_columns_clause=within_columns_clause,
            **kwargs,
        )
2473
    def visit_textual_label_reference(
        self, element, within_columns_clause=False, **kwargs
    ):
        """Render a string-named (textual) label reference in ORDER BY /
        GROUP BY, resolving the name against the enclosing statement's
        label-resolution dictionaries; raises CompileError when the name
        cannot be resolved.
        """
        if not self.stack:
            # compiling the element outside of the context of a SELECT
            return self.process(element._text_clause)

        try:
            compile_state = cast(
                "Union[SelectState, CompoundSelectState]",
                self.stack[-1]["compile_state"],
            )
        except KeyError as ke:
            coercions._no_text_coercion(
                element.element,
                extra=(
                    "Can't resolve label reference for ORDER BY / "
                    "GROUP BY / DISTINCT etc."
                ),
                exc_cls=exc.CompileError,
                err=ke,
            )

        with_cols, only_froms, only_cols = compile_state._label_resolve_dict
        try:
            # choose the resolution dictionary based on clause context
            if within_columns_clause:
                col = only_froms[element.element]
            else:
                col = with_cols[element.element]
        except KeyError as err:
            coercions._no_text_coercion(
                element.element,
                extra=(
                    "Can't resolve label reference for ORDER BY / "
                    "GROUP BY / DISTINCT etc."
                ),
                exc_cls=exc.CompileError,
                err=err,
            )
        else:
            kwargs["render_label_as_label"] = col
            return self.process(
                col, within_columns_clause=within_columns_clause, **kwargs
            )
2518
    def visit_label(
        self,
        label,
        add_to_result_map=None,
        within_label_clause=False,
        within_columns_clause=False,
        render_label_as_label=None,
        result_map_targets=(),
        **kw,
    ):
        """Render a Label in one of three forms: "<element> AS <name>"
        within a columns clause, just the label name when
        ``render_label_as_label`` targets this label, or only the
        contained element otherwise.
        """
        # only render labels within the columns clause
        # or ORDER BY clause of a select. dialect-specific compilers
        # can modify this behavior.
        render_label_with_as = (
            within_columns_clause and not within_label_clause
        )
        render_label_only = render_label_as_label is label

        if render_label_only or render_label_with_as:
            # resolve truncated label names to their deterministic
            # shortened form
            if isinstance(label.name, elements._truncated_label):
                labelname = self._truncated_identifier("colident", label.name)
            else:
                labelname = label.name

        if render_label_with_as:
            if add_to_result_map is not None:
                add_to_result_map(
                    labelname,
                    label.name,
                    (label, labelname) + label._alt_names + result_map_targets,
                    label.type,
                )
            return (
                label.element._compiler_dispatch(
                    self,
                    within_columns_clause=True,
                    within_label_clause=True,
                    **kw,
                )
                + OPERATORS[operators.as_]
                + self.preparer.format_label(label, labelname)
            )
        elif render_label_only:
            return self.preparer.format_label(label, labelname)
        else:
            return label.element._compiler_dispatch(
                self, within_columns_clause=False, **kw
            )
2567
2568 def _fallback_column_name(self, column):
2569 raise exc.CompileError(
2570 "Cannot compile Column object until its 'name' is assigned."
2571 )
2572
2573 def visit_lambda_element(self, element, **kw):
2574 sql_element = element._resolved
2575 return self.process(sql_element, **kw)
2576
    def visit_column(
        self,
        column: ColumnClause[Any],
        add_to_result_map: Optional[_ResultMapAppender] = None,
        include_table: bool = True,
        result_map_targets: Tuple[Any, ...] = (),
        ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
        **kwargs: Any,
    ) -> str:
        """Render a ColumnClause / Column, optionally qualified with its
        table (and schema) name, registering it in the result map when
        ``add_to_result_map`` is supplied.
        """
        name = orig_name = column.name
        if name is None:
            name = self._fallback_column_name(column)

        is_literal = column.is_literal
        # resolve truncated identifiers to their deterministic short form
        if not is_literal and isinstance(name, elements._truncated_label):
            name = self._truncated_identifier("colident", name)

        if add_to_result_map is not None:
            targets = (column, name, column.key) + result_map_targets
            if column._tq_label:
                targets += (column._tq_label,)

            add_to_result_map(name, orig_name, targets, column.type)

        if is_literal:
            # note we are not currently accommodating for
            # literal_column(quoted_name('ident', True)) here
            name = self.escape_literal_column(name)
        else:
            name = self.preparer.quote(name)
        table = column.table
        if table is None or not include_table or not table.named_with_column:
            # unqualified column name
            return name
        else:
            effective_schema = self.preparer.schema_for_object(table)

            if effective_schema:
                schema_prefix = (
                    self.preparer.quote_schema(effective_schema) + "."
                )
            else:
                schema_prefix = ""

            if TYPE_CHECKING:
                assert isinstance(table, NamedFromClause)
            tablename = table.name

            # substitute an anonymized table name when the plain name
            # collides within the enclosing statement
            if (
                not effective_schema
                and ambiguous_table_name_map
                and tablename in ambiguous_table_name_map
            ):
                tablename = ambiguous_table_name_map[tablename]

            if isinstance(tablename, elements._truncated_label):
                tablename = self._truncated_identifier("alias", tablename)

            return schema_prefix + self.preparer.quote(tablename) + "." + name
2635
2636 def visit_collation(self, element, **kw):
2637 return self.preparer.format_collation(element.collation)
2638
2639 def visit_fromclause(self, fromclause, **kwargs):
2640 return fromclause.name
2641
2642 def visit_index(self, index, **kwargs):
2643 return index.name
2644
2645 def visit_typeclause(self, typeclause, **kw):
2646 kw["type_expression"] = typeclause
2647 kw["identifier_preparer"] = self.preparer
2648 return self.dialect.type_compiler_instance.process(
2649 typeclause.type, **kw
2650 )
2651
2652 def post_process_text(self, text):
2653 if self.preparer._double_percents:
2654 text = text.replace("%", "%%")
2655 return text
2656
2657 def escape_literal_column(self, text):
2658 if self.preparer._double_percents:
2659 text = text.replace("%", "%%")
2660 return text
2661
2662 def visit_textclause(self, textclause, add_to_result_map=None, **kw):
2663 def do_bindparam(m):
2664 name = m.group(1)
2665 if name in textclause._bindparams:
2666 return self.process(textclause._bindparams[name], **kw)
2667 else:
2668 return self.bindparam_string(name, **kw)
2669
2670 if not self.stack:
2671 self.isplaintext = True
2672
2673 if add_to_result_map:
2674 # text() object is present in the columns clause of a
2675 # select(). Add a no-name entry to the result map so that
2676 # row[text()] produces a result
2677 add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE)
2678
2679 # un-escape any \:params
2680 return BIND_PARAMS_ESC.sub(
2681 lambda m: m.group(1),
2682 BIND_PARAMS.sub(
2683 do_bindparam, self.post_process_text(textclause.text)
2684 ),
2685 )
2686
    def visit_textual_select(
        self, taf, compound_index=None, asfrom=False, **kw
    ):
        """Render a TextualSelect (text() with column information),
        pushing a compiler stack entry and optionally populating the
        result map from the supplied column arguments.
        """
        toplevel = not self.stack
        entry = self._default_stack_entry if toplevel else self.stack[-1]

        new_entry: _CompilerStackEntry = {
            "correlate_froms": set(),
            "asfrom_froms": set(),
            "selectable": taf,
        }
        self.stack.append(new_entry)

        if taf._independent_ctes:
            self._dispatch_independent_ctes(taf, kw)

        # populate the result map at the top level, for the first member
        # of a compound that needs it, or for a nested context that
        # requested it
        populate_result_map = (
            toplevel
            or (
                compound_index == 0
                and entry.get("need_result_map_for_compound", False)
            )
            or entry.get("need_result_map_for_nested", False)
        )

        if populate_result_map:
            self._ordered_columns = self._textual_ordered_columns = (
                taf.positional
            )

            # enable looser result column matching when the SQL text links to
            # Column objects by name only
            self._loose_column_name_matching = not taf.positional and bool(
                taf.column_args
            )

            for c in taf.column_args:
                self.process(
                    c,
                    within_columns_clause=True,
                    add_to_result_map=self._add_to_result_map,
                )

        text = self.process(taf.element, **kw)
        if self.ctes:
            # prepend accumulated WITH clauses at the appropriate nesting
            nesting_level = len(self.stack) if not toplevel else None
            text = self._render_cte_clause(nesting_level=nesting_level) + text

        self.stack.pop(-1)

        return text
2738
2739 def visit_null(self, expr: Null, **kw: Any) -> str:
2740 return "NULL"
2741
2742 def visit_true(self, expr: True_, **kw: Any) -> str:
2743 if self.dialect.supports_native_boolean:
2744 return "true"
2745 else:
2746 return "1"
2747
2748 def visit_false(self, expr: False_, **kw: Any) -> str:
2749 if self.dialect.supports_native_boolean:
2750 return "false"
2751 else:
2752 return "0"
2753
2754 def _generate_delimited_list(self, elements, separator, **kw):
2755 return separator.join(
2756 s
2757 for s in (c._compiler_dispatch(self, **kw) for c in elements)
2758 if s
2759 )
2760
2761 def _generate_delimited_and_list(self, clauses, **kw):
2762 lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean(
2763 operators.and_,
2764 elements.True_._singleton,
2765 elements.False_._singleton,
2766 clauses,
2767 )
2768 if lcc == 1:
2769 return clauses[0]._compiler_dispatch(self, **kw)
2770 else:
2771 separator = OPERATORS[operators.and_]
2772 return separator.join(
2773 s
2774 for s in (c._compiler_dispatch(self, **kw) for c in clauses)
2775 if s
2776 )
2777
2778 def visit_tuple(self, clauselist, **kw):
2779 return "(%s)" % self.visit_clauselist(clauselist, **kw)
2780
2781 def visit_element_list(self, element, **kw):
2782 return self._generate_delimited_list(element.clauses, " ", **kw)
2783
2784 def visit_order_by_list(self, element, **kw):
2785 return self._generate_delimited_list(element.clauses, ", ", **kw)
2786
2787 def visit_clauselist(self, clauselist, **kw):
2788 sep = clauselist.operator
2789 if sep is None:
2790 sep = " "
2791 else:
2792 sep = OPERATORS[clauselist.operator]
2793
2794 return self._generate_delimited_list(clauselist.clauses, sep, **kw)
2795
2796 def visit_expression_clauselist(self, clauselist, **kw):
2797 operator_ = clauselist.operator
2798
2799 disp = self._get_operator_dispatch(
2800 operator_, "expression_clauselist", None
2801 )
2802 if disp:
2803 return disp(clauselist, operator_, **kw)
2804
2805 try:
2806 opstring = OPERATORS[operator_]
2807 except KeyError as err:
2808 raise exc.UnsupportedCompilationError(self, operator_) from err
2809 else:
2810 kw["_in_operator_expression"] = True
2811 return self._generate_delimited_list(
2812 clauselist.clauses, opstring, **kw
2813 )
2814
2815 def visit_case(self, clause, **kwargs):
2816 x = "CASE "
2817 if clause.value is not None:
2818 x += clause.value._compiler_dispatch(self, **kwargs) + " "
2819 for cond, result in clause.whens:
2820 x += (
2821 "WHEN "
2822 + cond._compiler_dispatch(self, **kwargs)
2823 + " THEN "
2824 + result._compiler_dispatch(self, **kwargs)
2825 + " "
2826 )
2827 if clause.else_ is not None:
2828 x += (
2829 "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " "
2830 )
2831 x += "END"
2832 return x
2833
2834 def visit_type_coerce(self, type_coerce, **kw):
2835 return type_coerce.typed_expression._compiler_dispatch(self, **kw)
2836
2837 def visit_cast(self, cast, **kwargs):
2838 type_clause = cast.typeclause._compiler_dispatch(self, **kwargs)
2839 match = re.match("(.*)( COLLATE .*)", type_clause)
2840 return "CAST(%s AS %s)%s" % (
2841 cast.clause._compiler_dispatch(self, **kwargs),
2842 match.group(1) if match else type_clause,
2843 match.group(2) if match else "",
2844 )
2845
2846 def visit_frame_clause(self, frameclause, **kw):
2847
2848 if frameclause.lower_type is elements._FrameClauseType.RANGE_UNBOUNDED:
2849 left = "UNBOUNDED PRECEDING"
2850 elif frameclause.lower_type is elements._FrameClauseType.RANGE_CURRENT:
2851 left = "CURRENT ROW"
2852 else:
2853 val = self.process(frameclause.lower_integer_bind, **kw)
2854 if (
2855 frameclause.lower_type
2856 is elements._FrameClauseType.RANGE_PRECEDING
2857 ):
2858 left = f"{val} PRECEDING"
2859 else:
2860 left = f"{val} FOLLOWING"
2861
2862 if frameclause.upper_type is elements._FrameClauseType.RANGE_UNBOUNDED:
2863 right = "UNBOUNDED FOLLOWING"
2864 elif frameclause.upper_type is elements._FrameClauseType.RANGE_CURRENT:
2865 right = "CURRENT ROW"
2866 else:
2867 val = self.process(frameclause.upper_integer_bind, **kw)
2868 if (
2869 frameclause.upper_type
2870 is elements._FrameClauseType.RANGE_PRECEDING
2871 ):
2872 right = f"{val} PRECEDING"
2873 else:
2874 right = f"{val} FOLLOWING"
2875
2876 return f"{left} AND {right}"
2877
2878 def visit_over(self, over, **kwargs):
2879 text = over.element._compiler_dispatch(self, **kwargs)
2880 if over.range_ is not None:
2881 range_ = f"RANGE BETWEEN {self.process(over.range_, **kwargs)}"
2882 elif over.rows is not None:
2883 range_ = f"ROWS BETWEEN {self.process(over.rows, **kwargs)}"
2884 elif over.groups is not None:
2885 range_ = f"GROUPS BETWEEN {self.process(over.groups, **kwargs)}"
2886 else:
2887 range_ = None
2888
2889 return "%s OVER (%s)" % (
2890 text,
2891 " ".join(
2892 [
2893 "%s BY %s"
2894 % (word, clause._compiler_dispatch(self, **kwargs))
2895 for word, clause in (
2896 ("PARTITION", over.partition_by),
2897 ("ORDER", over.order_by),
2898 )
2899 if clause is not None and len(clause)
2900 ]
2901 + ([range_] if range_ else [])
2902 ),
2903 )
2904
2905 def visit_withingroup(self, withingroup, **kwargs):
2906 return "%s WITHIN GROUP (ORDER BY %s)" % (
2907 withingroup.element._compiler_dispatch(self, **kwargs),
2908 withingroup.order_by._compiler_dispatch(self, **kwargs),
2909 )
2910
2911 def visit_funcfilter(self, funcfilter, **kwargs):
2912 return "%s FILTER (WHERE %s)" % (
2913 funcfilter.func._compiler_dispatch(self, **kwargs),
2914 funcfilter.criterion._compiler_dispatch(self, **kwargs),
2915 )
2916
2917 def visit_extract(self, extract, **kwargs):
2918 field = self.extract_map.get(extract.field, extract.field)
2919 return "EXTRACT(%s FROM %s)" % (
2920 field,
2921 extract.expr._compiler_dispatch(self, **kwargs),
2922 )
2923
2924 def visit_scalar_function_column(self, element, **kw):
2925 compiled_fn = self.visit_function(element.fn, **kw)
2926 compiled_col = self.visit_column(element, **kw)
2927 return "(%s).%s" % (compiled_fn, compiled_col)
2928
    def visit_function(
        self,
        func: Function[Any],
        add_to_result_map: Optional[_ResultMapAppender] = None,
        **kwargs: Any,
    ) -> str:
        """Render a SQL function call, preferring a dialect-level
        ``visit_<funcname>_func`` hook, then a known generic FUNCTIONS
        template, then the (possibly package-qualified, possibly quoted)
        function name with its argument list.
        """
        if add_to_result_map is not None:
            add_to_result_map(func.name, func.name, (func.name,), func.type)

        disp = getattr(self, "visit_%s_func" % func.name.lower(), None)

        text: str

        if disp:
            text = disp(func, **kwargs)
        else:
            name = FUNCTIONS.get(func._deannotate().__class__, None)
            if name:
                if func._has_args:
                    name += "%(expr)s"
            else:
                name = func.name
                name = (
                    self.preparer.quote(name)
                    if self.preparer._requires_quotes_illegal_chars(name)
                    or isinstance(name, elements.quoted_name)
                    else name
                )
                name = name + "%(expr)s"
            text = ".".join(
                [
                    (
                        self.preparer.quote(tok)
                        if self.preparer._requires_quotes_illegal_chars(tok)
                        # NOTE(review): this checks ``name`` (not ``tok``)
                        # for quoted_name, unlike the illegal-chars check
                        # above it -- confirm this asymmetry is intended
                        or isinstance(name, elements.quoted_name)
                        else tok
                    )
                    for tok in func.packagenames
                ]
                + [name]
            ) % {"expr": self.function_argspec(func, **kwargs)}

        if func._with_ordinality:
            text += " WITH ORDINALITY"
        return text
2974
    def visit_next_value_func(self, next_value, **kw):
        """Render a sequence "next value" expression via
        visit_sequence().

        NOTE(review): ``**kw`` is not forwarded to visit_sequence() --
        confirm dialect implementations do not rely on compiler kwargs
        here.
        """
        return self.visit_sequence(next_value.sequence)
2977
2978 def visit_sequence(self, sequence, **kw):
2979 raise NotImplementedError(
2980 "Dialect '%s' does not support sequence increments."
2981 % self.dialect.name
2982 )
2983
2984 def function_argspec(self, func: Function[Any], **kwargs: Any) -> str:
2985 return func.clause_expr._compiler_dispatch(self, **kwargs)
2986
    def visit_compound_select(
        self, cs, asfrom=False, compound_index=None, **kwargs
    ):
        """Render a CompoundSelect (UNION / INTERSECT / EXCEPT etc.),
        joining the member selects with the compound keyword and
        appending GROUP BY / ORDER BY / row-limiting / CTE clauses.
        """
        toplevel = not self.stack

        compile_state = cs._compile_state_factory(cs, self, **kwargs)

        if toplevel and not self.compile_state:
            self.compile_state = compile_state

        compound_stmt = compile_state.statement

        entry = self._default_stack_entry if toplevel else self.stack[-1]
        # only the first member of a compound populates the result map
        need_result_map = toplevel or (
            not compound_index
            and entry.get("need_result_map_for_compound", False)
        )

        # indicates there is already a CompoundSelect in play
        if compound_index == 0:
            entry["select_0"] = cs

        self.stack.append(
            {
                "correlate_froms": entry["correlate_froms"],
                "asfrom_froms": entry["asfrom_froms"],
                "selectable": cs,
                "compile_state": compile_state,
                "need_result_map_for_compound": need_result_map,
            }
        )

        if compound_stmt._independent_ctes:
            self._dispatch_independent_ctes(compound_stmt, kwargs)

        keyword = self.compound_keywords[cs.keyword]

        text = (" " + keyword + " ").join(
            (
                c._compiler_dispatch(
                    self, asfrom=asfrom, compound_index=i, **kwargs
                )
                for i, c in enumerate(cs.selects)
            )
        )

        kwargs["include_table"] = False
        text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs))
        text += self.order_by_clause(cs, **kwargs)
        if cs._has_row_limiting_clause:
            text += self._row_limit_clause(cs, **kwargs)

        if self.ctes:
            # prepend accumulated WITH clauses at the appropriate nesting
            nesting_level = len(self.stack) if not toplevel else None
            text = (
                self._render_cte_clause(
                    nesting_level=nesting_level,
                    include_following_stack=True,
                )
                + text
            )

        self.stack.pop(-1)
        return text
3051
3052 def _row_limit_clause(self, cs, **kwargs):
3053 if cs._fetch_clause is not None:
3054 return self.fetch_clause(cs, **kwargs)
3055 else:
3056 return self.limit_clause(cs, **kwargs)
3057
3058 def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
3059 attrname = "visit_%s_%s%s" % (
3060 operator_.__name__,
3061 qualifier1,
3062 "_" + qualifier2 if qualifier2 else "",
3063 )
3064 return getattr(self, attrname, None)
3065
3066 def visit_unary(
3067 self, unary, add_to_result_map=None, result_map_targets=(), **kw
3068 ):
3069 if add_to_result_map is not None:
3070 result_map_targets += (unary,)
3071 kw["add_to_result_map"] = add_to_result_map
3072 kw["result_map_targets"] = result_map_targets
3073
3074 if unary.operator:
3075 if unary.modifier:
3076 raise exc.CompileError(
3077 "Unary expression does not support operator "
3078 "and modifier simultaneously"
3079 )
3080 disp = self._get_operator_dispatch(
3081 unary.operator, "unary", "operator"
3082 )
3083 if disp:
3084 return disp(unary, unary.operator, **kw)
3085 else:
3086 return self._generate_generic_unary_operator(
3087 unary, OPERATORS[unary.operator], **kw
3088 )
3089 elif unary.modifier:
3090 disp = self._get_operator_dispatch(
3091 unary.modifier, "unary", "modifier"
3092 )
3093 if disp:
3094 return disp(unary, unary.modifier, **kw)
3095 else:
3096 return self._generate_generic_unary_modifier(
3097 unary, OPERATORS[unary.modifier], **kw
3098 )
3099 else:
3100 raise exc.CompileError(
3101 "Unary expression has no operator or modifier"
3102 )
3103
3104 def visit_truediv_binary(self, binary, operator, **kw):
3105 if self.dialect.div_is_floordiv:
3106 return (
3107 self.process(binary.left, **kw)
3108 + " / "
3109 # TODO: would need a fast cast again here,
3110 # unless we want to use an implicit cast like "+ 0.0"
3111 + self.process(
3112 elements.Cast(
3113 binary.right,
3114 (
3115 binary.right.type
3116 if binary.right.type._type_affinity
3117 in (sqltypes.Numeric, sqltypes.Float)
3118 else sqltypes.Numeric()
3119 ),
3120 ),
3121 **kw,
3122 )
3123 )
3124 else:
3125 return (
3126 self.process(binary.left, **kw)
3127 + " / "
3128 + self.process(binary.right, **kw)
3129 )
3130
3131 def visit_floordiv_binary(self, binary, operator, **kw):
3132 if (
3133 self.dialect.div_is_floordiv
3134 and binary.right.type._type_affinity is sqltypes.Integer
3135 ):
3136 return (
3137 self.process(binary.left, **kw)
3138 + " / "
3139 + self.process(binary.right, **kw)
3140 )
3141 else:
3142 return "FLOOR(%s)" % (
3143 self.process(binary.left, **kw)
3144 + " / "
3145 + self.process(binary.right, **kw)
3146 )
3147
3148 def visit_is_true_unary_operator(self, element, operator, **kw):
3149 if (
3150 element._is_implicitly_boolean
3151 or self.dialect.supports_native_boolean
3152 ):
3153 return self.process(element.element, **kw)
3154 else:
3155 return "%s = 1" % self.process(element.element, **kw)
3156
3157 def visit_is_false_unary_operator(self, element, operator, **kw):
3158 if (
3159 element._is_implicitly_boolean
3160 or self.dialect.supports_native_boolean
3161 ):
3162 return "NOT %s" % self.process(element.element, **kw)
3163 else:
3164 return "%s = 0" % self.process(element.element, **kw)
3165
3166 def visit_not_match_op_binary(self, binary, operator, **kw):
3167 return "NOT %s" % self.visit_binary(
3168 binary, override_operator=operators.match_op
3169 )
3170
3171 def visit_not_in_op_binary(self, binary, operator, **kw):
3172 # The brackets are required in the NOT IN operation because the empty
3173 # case is handled using the form "(col NOT IN (null) OR 1 = 1)".
3174 # The presence of the OR makes the brackets required.
3175 return "(%s)" % self._generate_generic_binary(
3176 binary, OPERATORS[operator], **kw
3177 )
3178
3179 def visit_empty_set_op_expr(self, type_, expand_op, **kw):
3180 if expand_op is operators.not_in_op:
3181 if len(type_) > 1:
3182 return "(%s)) OR (1 = 1" % (
3183 ", ".join("NULL" for element in type_)
3184 )
3185 else:
3186 return "NULL) OR (1 = 1"
3187 elif expand_op is operators.in_op:
3188 if len(type_) > 1:
3189 return "(%s)) AND (1 != 1" % (
3190 ", ".join("NULL" for element in type_)
3191 )
3192 else:
3193 return "NULL) AND (1 != 1"
3194 else:
3195 return self.visit_empty_set_expr(type_)
3196
    def visit_empty_set_expr(self, element_types, **kw):
        """Hook for rendering an "empty set" expression for the given
        element types; dialects that support this override the method.

        :raises NotImplementedError: always, in the base compiler.
        """
        raise NotImplementedError(
            "Dialect '%s' does not support empty set expression."
            % self.dialect.name
        )
3202
    def _literal_execute_expanding_parameter_literal_binds(
        self, parameter, values, bind_expression_template=None
    ):
        """Render an "expanding" (IN) parameter as inline literal values.

        Returns ``(to_update, replacement_expression)``; ``to_update`` is
        always empty here since no bound parameters remain, and
        ``replacement_expression`` is the comma-separated literal text
        that replaces the POSTCOMPILE placeholder.

        :param bind_expression_template: pre-rendered bind_expression()
         text containing a placeholder; per-value literals are spliced
         into it when present.
        """
        typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)

        if not values:
            # empty IN expression. note we don't need to use
            # bind_expression_template here because there are no
            # expressions to render.

            if typ_dialect_impl._is_tuple_type:
                replacement_expression = (
                    "VALUES " if self.dialect.tuple_in_values else ""
                ) + self.visit_empty_set_op_expr(
                    parameter.type.types, parameter.expand_op
                )

            else:
                replacement_expression = self.visit_empty_set_op_expr(
                    [parameter.type], parameter.expand_op
                )

        elif typ_dialect_impl._is_tuple_type or (
            typ_dialect_impl._isnull
            and isinstance(values[0], collections_abc.Sequence)
            and not isinstance(values[0], (str, bytes))
        ):
            # tuple-valued IN, e.g. (a, b) IN ((1, 2), (3, 4)); each
            # element renders as a parenthesized tuple of literals
            if typ_dialect_impl._has_bind_expression:
                raise NotImplementedError(
                    "bind_expression() on TupleType not supported with "
                    "literal_binds"
                )

            replacement_expression = (
                "VALUES " if self.dialect.tuple_in_values else ""
            ) + ", ".join(
                "(%s)"
                % (
                    ", ".join(
                        self.render_literal_value(value, param_type)
                        for value, param_type in zip(
                            tuple_element, parameter.type.types
                        )
                    )
                )
                for i, tuple_element in enumerate(values)
            )
        else:
            if bind_expression_template:
                # bind_expression() in play; locate the "~~"-delimited
                # placeholder within the template and wrap each literal
                # with the text on either side of it
                post_compile_pattern = self._post_compile_pattern
                m = post_compile_pattern.search(bind_expression_template)
                assert m and m.group(
                    2
                ), "unexpected format for expanding parameter"

                tok = m.group(2).split("~~")
                be_left, be_right = tok[1], tok[3]
                replacement_expression = ", ".join(
                    "%s%s%s"
                    % (
                        be_left,
                        self.render_literal_value(value, parameter.type),
                        be_right,
                    )
                    for value in values
                )
            else:
                # plain scalar IN: one literal per value
                replacement_expression = ", ".join(
                    self.render_literal_value(value, parameter.type)
                    for value in values
                )

        return (), replacement_expression
3276
    def _literal_execute_expanding_parameter(self, name, parameter, values):
        """Expand an "expanding" (IN) parameter into individual binds,
        one per value.

        Returns ``(to_update, replacement_expression)`` where
        ``to_update`` is a list of ``(param_name, value)`` pairs to merge
        into the execution parameters and ``replacement_expression`` is
        the text replacing the POSTCOMPILE placeholder.
        """
        if parameter.literal_execute:
            # literal_execute: render inline literals rather than binds
            return self._literal_execute_expanding_parameter_literal_binds(
                parameter, values
            )

        dialect = self.dialect
        typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect)

        if self._numeric_binds:
            # numeric paramstyle assigns positions in a later pass, so
            # render with the intermediary template here
            bind_template = self.compilation_bindtemplate
        else:
            bind_template = self.bindtemplate

        if (
            self.dialect._bind_typing_render_casts
            and typ_dialect_impl.render_bind_cast
        ):

            def _render_bindtemplate(name):
                # wrap each expanded bind in the dialect's render cast
                return self.render_bind_cast(
                    parameter.type,
                    typ_dialect_impl,
                    bind_template % {"name": name},
                )

        else:

            def _render_bindtemplate(name):
                return bind_template % {"name": name}

        if not values:
            # empty IN: no binds; render a constant-true/false fragment
            to_update = []
            if typ_dialect_impl._is_tuple_type:
                replacement_expression = self.visit_empty_set_op_expr(
                    parameter.type.types, parameter.expand_op
                )
            else:
                replacement_expression = self.visit_empty_set_op_expr(
                    [parameter.type], parameter.expand_op
                )

        elif typ_dialect_impl._is_tuple_type or (
            typ_dialect_impl._isnull
            and isinstance(values[0], collections_abc.Sequence)
            and not isinstance(values[0], (str, bytes))
        ):
            # tuple IN: one bind per tuple member, named "<name>_<i>_<j>"
            assert not typ_dialect_impl._is_array
            to_update = [
                ("%s_%s_%s" % (name, i, j), value)
                for i, tuple_element in enumerate(values, 1)
                for j, value in enumerate(tuple_element, 1)
            ]

            replacement_expression = (
                "VALUES " if dialect.tuple_in_values else ""
            ) + ", ".join(
                "(%s)"
                % (
                    ", ".join(
                        _render_bindtemplate(
                            to_update[i * len(tuple_element) + j][0]
                        )
                        for j, value in enumerate(tuple_element)
                    )
                )
                for i, tuple_element in enumerate(values)
            )
        else:
            # scalar IN: one bind per value, named "<name>_<i>"
            to_update = [
                ("%s_%s" % (name, i), value)
                for i, value in enumerate(values, 1)
            ]
            replacement_expression = ", ".join(
                _render_bindtemplate(key) for key, value in to_update
            )

        return to_update, replacement_expression
3355
    def visit_binary(
        self,
        binary,
        override_operator=None,
        eager_grouping=False,
        from_linter=None,
        lateral_from_linter=None,
        **kw,
    ):
        """Render a binary expression.

        Dispatches to a per-operator ``visit_<op>_binary`` method when one
        is registered, otherwise renders ``left <opstring> right`` using
        the generic OPERATORS lookup.

        :param override_operator: operator to use in place of
         ``binary.operator``.
        :param from_linter: when present, comparison operators record an
         edge between the FROM objects of both sides for cartesian
         product detection.
        :param lateral_from_linter: linter for the enclosing statement
         when compiling inside a LATERAL subquery.
        """
        if from_linter and operators.is_comparison(binary.operator):
            if lateral_from_linter is not None:
                # within a LATERAL: record edges on the enclosing
                # statement's linter, including the lateral itself
                enclosing_lateral = kw["enclosing_lateral"]
                lateral_from_linter.edges.update(
                    itertools.product(
                        _de_clone(
                            binary.left._from_objects + [enclosing_lateral]
                        ),
                        _de_clone(
                            binary.right._from_objects + [enclosing_lateral]
                        ),
                    )
                )
            else:
                from_linter.edges.update(
                    itertools.product(
                        _de_clone(binary.left._from_objects),
                        _de_clone(binary.right._from_objects),
                    )
                )

        # don't allow "? = ?" to render
        if (
            self.ansi_bind_rules
            and isinstance(binary.left, elements.BindParameter)
            and isinstance(binary.right, elements.BindParameter)
        ):
            kw["literal_execute"] = True

        operator_ = override_operator or binary.operator
        disp = self._get_operator_dispatch(operator_, "binary", None)
        if disp:
            return disp(binary, operator_, **kw)
        else:
            try:
                opstring = OPERATORS[operator_]
            except KeyError as err:
                raise exc.UnsupportedCompilationError(self, operator_) from err
            else:
                return self._generate_generic_binary(
                    binary,
                    opstring,
                    from_linter=from_linter,
                    lateral_from_linter=lateral_from_linter,
                    **kw,
                )
3411
3412 def visit_function_as_comparison_op_binary(self, element, operator, **kw):
3413 return self.process(element.sql_function, **kw)
3414
3415 def visit_mod_binary(self, binary, operator, **kw):
3416 if self.preparer._double_percents:
3417 return (
3418 self.process(binary.left, **kw)
3419 + " %% "
3420 + self.process(binary.right, **kw)
3421 )
3422 else:
3423 return (
3424 self.process(binary.left, **kw)
3425 + " % "
3426 + self.process(binary.right, **kw)
3427 )
3428
3429 def visit_custom_op_binary(self, element, operator, **kw):
3430 kw["eager_grouping"] = operator.eager_grouping
3431 return self._generate_generic_binary(
3432 element,
3433 " " + self.escape_literal_column(operator.opstring) + " ",
3434 **kw,
3435 )
3436
3437 def visit_custom_op_unary_operator(self, element, operator, **kw):
3438 return self._generate_generic_unary_operator(
3439 element, self.escape_literal_column(operator.opstring) + " ", **kw
3440 )
3441
3442 def visit_custom_op_unary_modifier(self, element, operator, **kw):
3443 return self._generate_generic_unary_modifier(
3444 element, " " + self.escape_literal_column(operator.opstring), **kw
3445 )
3446
3447 def _generate_generic_binary(
3448 self,
3449 binary: BinaryExpression[Any],
3450 opstring: str,
3451 eager_grouping: bool = False,
3452 **kw: Any,
3453 ) -> str:
3454 _in_operator_expression = kw.get("_in_operator_expression", False)
3455
3456 kw["_in_operator_expression"] = True
3457 kw["_binary_op"] = binary.operator
3458 text = (
3459 binary.left._compiler_dispatch(
3460 self, eager_grouping=eager_grouping, **kw
3461 )
3462 + opstring
3463 + binary.right._compiler_dispatch(
3464 self, eager_grouping=eager_grouping, **kw
3465 )
3466 )
3467
3468 if _in_operator_expression and eager_grouping:
3469 text = "(%s)" % text
3470 return text
3471
3472 def _generate_generic_unary_operator(self, unary, opstring, **kw):
3473 return opstring + unary.element._compiler_dispatch(self, **kw)
3474
3475 def _generate_generic_unary_modifier(self, unary, opstring, **kw):
3476 return unary.element._compiler_dispatch(self, **kw) + opstring
3477
    @util.memoized_property
    def _like_percent_literal(self):
        # a literal-rendered '%' wildcard column, shared by the
        # contains/startswith/endswith family of LIKE composers;
        # memoized since it is reused across compilations
        return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE)
3481
3482 def visit_ilike_case_insensitive_operand(self, element, **kw):
3483 return f"lower({element.element._compiler_dispatch(self, **kw)})"
3484
3485 def visit_contains_op_binary(self, binary, operator, **kw):
3486 binary = binary._clone()
3487 percent = self._like_percent_literal
3488 binary.right = percent.concat(binary.right).concat(percent)
3489 return self.visit_like_op_binary(binary, operator, **kw)
3490
3491 def visit_not_contains_op_binary(self, binary, operator, **kw):
3492 binary = binary._clone()
3493 percent = self._like_percent_literal
3494 binary.right = percent.concat(binary.right).concat(percent)
3495 return self.visit_not_like_op_binary(binary, operator, **kw)
3496
3497 def visit_icontains_op_binary(self, binary, operator, **kw):
3498 binary = binary._clone()
3499 percent = self._like_percent_literal
3500 binary.left = ilike_case_insensitive(binary.left)
3501 binary.right = percent.concat(
3502 ilike_case_insensitive(binary.right)
3503 ).concat(percent)
3504 return self.visit_ilike_op_binary(binary, operator, **kw)
3505
3506 def visit_not_icontains_op_binary(self, binary, operator, **kw):
3507 binary = binary._clone()
3508 percent = self._like_percent_literal
3509 binary.left = ilike_case_insensitive(binary.left)
3510 binary.right = percent.concat(
3511 ilike_case_insensitive(binary.right)
3512 ).concat(percent)
3513 return self.visit_not_ilike_op_binary(binary, operator, **kw)
3514
3515 def visit_startswith_op_binary(self, binary, operator, **kw):
3516 binary = binary._clone()
3517 percent = self._like_percent_literal
3518 binary.right = percent._rconcat(binary.right)
3519 return self.visit_like_op_binary(binary, operator, **kw)
3520
3521 def visit_not_startswith_op_binary(self, binary, operator, **kw):
3522 binary = binary._clone()
3523 percent = self._like_percent_literal
3524 binary.right = percent._rconcat(binary.right)
3525 return self.visit_not_like_op_binary(binary, operator, **kw)
3526
3527 def visit_istartswith_op_binary(self, binary, operator, **kw):
3528 binary = binary._clone()
3529 percent = self._like_percent_literal
3530 binary.left = ilike_case_insensitive(binary.left)
3531 binary.right = percent._rconcat(ilike_case_insensitive(binary.right))
3532 return self.visit_ilike_op_binary(binary, operator, **kw)
3533
3534 def visit_not_istartswith_op_binary(self, binary, operator, **kw):
3535 binary = binary._clone()
3536 percent = self._like_percent_literal
3537 binary.left = ilike_case_insensitive(binary.left)
3538 binary.right = percent._rconcat(ilike_case_insensitive(binary.right))
3539 return self.visit_not_ilike_op_binary(binary, operator, **kw)
3540
3541 def visit_endswith_op_binary(self, binary, operator, **kw):
3542 binary = binary._clone()
3543 percent = self._like_percent_literal
3544 binary.right = percent.concat(binary.right)
3545 return self.visit_like_op_binary(binary, operator, **kw)
3546
3547 def visit_not_endswith_op_binary(self, binary, operator, **kw):
3548 binary = binary._clone()
3549 percent = self._like_percent_literal
3550 binary.right = percent.concat(binary.right)
3551 return self.visit_not_like_op_binary(binary, operator, **kw)
3552
3553 def visit_iendswith_op_binary(self, binary, operator, **kw):
3554 binary = binary._clone()
3555 percent = self._like_percent_literal
3556 binary.left = ilike_case_insensitive(binary.left)
3557 binary.right = percent.concat(ilike_case_insensitive(binary.right))
3558 return self.visit_ilike_op_binary(binary, operator, **kw)
3559
3560 def visit_not_iendswith_op_binary(self, binary, operator, **kw):
3561 binary = binary._clone()
3562 percent = self._like_percent_literal
3563 binary.left = ilike_case_insensitive(binary.left)
3564 binary.right = percent.concat(ilike_case_insensitive(binary.right))
3565 return self.visit_not_ilike_op_binary(binary, operator, **kw)
3566
3567 def visit_like_op_binary(self, binary, operator, **kw):
3568 escape = binary.modifiers.get("escape", None)
3569
3570 return "%s LIKE %s" % (
3571 binary.left._compiler_dispatch(self, **kw),
3572 binary.right._compiler_dispatch(self, **kw),
3573 ) + (
3574 " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
3575 if escape is not None
3576 else ""
3577 )
3578
3579 def visit_not_like_op_binary(self, binary, operator, **kw):
3580 escape = binary.modifiers.get("escape", None)
3581 return "%s NOT LIKE %s" % (
3582 binary.left._compiler_dispatch(self, **kw),
3583 binary.right._compiler_dispatch(self, **kw),
3584 ) + (
3585 " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
3586 if escape is not None
3587 else ""
3588 )
3589
3590 def visit_ilike_op_binary(self, binary, operator, **kw):
3591 if operator is operators.ilike_op:
3592 binary = binary._clone()
3593 binary.left = ilike_case_insensitive(binary.left)
3594 binary.right = ilike_case_insensitive(binary.right)
3595 # else we assume ilower() has been applied
3596
3597 return self.visit_like_op_binary(binary, operator, **kw)
3598
3599 def visit_not_ilike_op_binary(self, binary, operator, **kw):
3600 if operator is operators.not_ilike_op:
3601 binary = binary._clone()
3602 binary.left = ilike_case_insensitive(binary.left)
3603 binary.right = ilike_case_insensitive(binary.right)
3604 # else we assume ilower() has been applied
3605
3606 return self.visit_not_like_op_binary(binary, operator, **kw)
3607
3608 def visit_between_op_binary(self, binary, operator, **kw):
3609 symmetric = binary.modifiers.get("symmetric", False)
3610 return self._generate_generic_binary(
3611 binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw
3612 )
3613
3614 def visit_not_between_op_binary(self, binary, operator, **kw):
3615 symmetric = binary.modifiers.get("symmetric", False)
3616 return self._generate_generic_binary(
3617 binary,
3618 " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ",
3619 **kw,
3620 )
3621
    def visit_regexp_match_op_binary(
        self, binary: BinaryExpression[Any], operator: Any, **kw: Any
    ) -> str:
        """Hook for ``regexp_match()``; overridden by dialects that
        support regular expression operators.

        :raises CompileError: always, in the base compiler.
        """
        raise exc.CompileError(
            "%s dialect does not support regular expressions"
            % self.dialect.name
        )
3629
    def visit_not_regexp_match_op_binary(
        self, binary: BinaryExpression[Any], operator: Any, **kw: Any
    ) -> str:
        """Hook for negated ``regexp_match()``; overridden by dialects
        that support regular expression operators.

        :raises CompileError: always, in the base compiler.
        """
        raise exc.CompileError(
            "%s dialect does not support regular expressions"
            % self.dialect.name
        )
3637
    def visit_regexp_replace_op_binary(
        self, binary: BinaryExpression[Any], operator: Any, **kw: Any
    ) -> str:
        """Hook for ``regexp_replace()``; overridden by dialects that
        support regular expression replacement.

        :raises CompileError: always, in the base compiler.
        """
        raise exc.CompileError(
            "%s dialect does not support regular expression replacements"
            % self.dialect.name
        )
3645
3646 def visit_dmltargetcopy(self, element, *, bindmarkers=None, **kw):
3647 if bindmarkers is None:
3648 raise exc.CompileError(
3649 "DML target objects may only be used with "
3650 "compiled INSERT or UPDATE statements"
3651 )
3652
3653 bindmarkers[element.column.key] = element
3654 return f"__BINDMARKER_~~{element.column.key}~~"
3655
    def visit_bindparam(
        self,
        bindparam,
        within_columns_clause=False,
        literal_binds=False,
        skip_bind_expression=False,
        literal_execute=False,
        render_postcompile=False,
        **kwargs,
    ):
        """Render a bound parameter, registering it in ``self.binds`` and
        returning its marker text (paramstyle marker, inline literal, or
        POSTCOMPILE placeholder, depending on flags).

        :param literal_binds: render the parameter's value inline as a
         literal instead of a bound marker.
        :param literal_execute: defer value rendering to execution time
         via the post-compile mechanism.
        :param render_postcompile: indicates this compilation will expand
         POSTCOMPILE placeholders itself.
        """

        if not skip_bind_expression:
            impl = bindparam.type.dialect_impl(self.dialect)
            if impl._has_bind_expression:
                # the type wraps the parameter in a SQL expression;
                # compile that wrapper with this parameter inside it,
                # skipping re-entry into this branch
                bind_expression = impl.bind_expression(bindparam)
                wrapped = self.process(
                    bind_expression,
                    skip_bind_expression=True,
                    within_columns_clause=within_columns_clause,
                    literal_binds=literal_binds and not bindparam.expanding,
                    literal_execute=literal_execute,
                    render_postcompile=render_postcompile,
                    **kwargs,
                )
                if bindparam.expanding:
                    # for postcompile w/ expanding, move the "wrapped" part
                    # of this into the inside

                    m = re.match(
                        r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
                    )
                    assert m, "unexpected format for expanding parameter"
                    wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
                        m.group(2),
                        m.group(1),
                        m.group(3),
                    )

                if literal_binds:
                    ret = self.render_literal_bindparam(
                        bindparam,
                        within_columns_clause=True,
                        bind_expression_template=wrapped,
                        **kwargs,
                    )
                    return f"({ret})"

                return wrapped

        if not literal_binds:
            literal_execute = (
                literal_execute
                or bindparam.literal_execute
                or (within_columns_clause and self.ansi_bind_rules)
            )
            post_compile = literal_execute or bindparam.expanding
        else:
            post_compile = False

        if literal_binds:
            ret = self.render_literal_bindparam(
                bindparam, within_columns_clause=True, **kwargs
            )
            if bindparam.expanding:
                ret = f"({ret})"
            return ret

        name = self._truncate_bindparam(bindparam)

        if name in self.binds:
            # name collision: raise for the incompatible cases, pass
            # silently when the parameters are considered equivalent
            existing = self.binds[name]
            if existing is not bindparam:
                if (
                    (existing.unique or bindparam.unique)
                    and not existing.proxy_set.intersection(
                        bindparam.proxy_set
                    )
                    and not existing._cloned_set.intersection(
                        bindparam._cloned_set
                    )
                ):
                    raise exc.CompileError(
                        "Bind parameter '%s' conflicts with "
                        "unique bind parameter of the same name" % name
                    )
                elif existing.expanding != bindparam.expanding:
                    raise exc.CompileError(
                        "Can't reuse bound parameter name '%s' in both "
                        "'expanding' (e.g. within an IN expression) and "
                        "non-expanding contexts. If this parameter is to "
                        "receive a list/array value, set 'expanding=True' on "
                        "it for expressions that aren't IN, otherwise use "
                        "a different parameter name." % (name,)
                    )
                elif existing._is_crud or bindparam._is_crud:
                    if existing._is_crud and bindparam._is_crud:
                        # TODO: this condition is not well understood.
                        # see tests in test/sql/test_update.py
                        raise exc.CompileError(
                            "Encountered unsupported case when compiling an "
                            "INSERT or UPDATE statement. If this is a "
                            "multi-table "
                            "UPDATE statement, please provide string-named "
                            "arguments to the "
                            "values() method with distinct names; support for "
                            "multi-table UPDATE statements that "
                            "target multiple tables for UPDATE is very "
                            "limited",
                        )
                    else:
                        raise exc.CompileError(
                            f"bindparam() name '{bindparam.key}' is reserved "
                            "for automatic usage in the VALUES or SET "
                            "clause of this "
                            "insert/update statement. Please use a "
                            "name other than column name when using "
                            "bindparam() "
                            "with insert() or update() (for example, "
                            f"'b_{bindparam.key}')."
                        )

        self.binds[bindparam.key] = self.binds[name] = bindparam

        # if we are given a cache key that we're going to match against,
        # relate the bindparam here to one that is most likely present
        # in the "extracted params" portion of the cache key. this is used
        # to set up a positional mapping that is used to determine the
        # correct parameters for a subsequent use of this compiled with
        # a different set of parameter values. here, we accommodate for
        # parameters that may have been cloned both before and after the
        # cache key was generated.
        ckbm_tuple = self._cache_key_bind_match

        if ckbm_tuple:
            ckbm, cksm = ckbm_tuple
            for bp in bindparam._cloned_set:
                if bp.key in cksm:
                    cb = cksm[bp.key]
                    ckbm[cb].append(bindparam)

        if bindparam.isoutparam:
            self.has_out_parameters = True

        if post_compile:
            if render_postcompile:
                self._render_postcompile = True

            if literal_execute:
                self.literal_execute_params |= {bindparam}
            else:
                self.post_compile_params |= {bindparam}

        ret = self.bindparam_string(
            name,
            post_compile=post_compile,
            expanding=bindparam.expanding,
            bindparam_type=bindparam.type,
            **kwargs,
        )

        if bindparam.expanding:
            ret = f"({ret})"

        return ret
3820
    def render_bind_cast(self, type_, dbapi_type, sqltext):
        """Render a cast around rendered bound-parameter text.

        Invoked (see ``bindparam_string``) when the dialect sets
        ``_bind_typing_render_casts`` and the type implementation sets
        ``render_bind_cast`` / ``render_literal_cast``; such dialects
        must override this method.
        """
        raise NotImplementedError()
3823
    def render_literal_bindparam(
        self,
        bindparam,
        render_literal_value=NO_ARG,
        bind_expression_template=None,
        **kw,
    ):
        """Render a bound parameter as an inline literal value.

        :param render_literal_value: when given, render this exact value
         rather than the parameter's own value.
        :param bind_expression_template: pre-rendered bind_expression()
         text for expanding parameters, forwarded to the expanding
         literal renderer.
        """
        if render_literal_value is not NO_ARG:
            value = render_literal_value
        else:
            if bindparam.value is None and bindparam.callable is None:
                # rendering a literal NULL; warn if it is compared with
                # an operator other than IS / IS NOT
                op = kw.get("_binary_op", None)
                if op and op not in (operators.is_, operators.is_not):
                    util.warn_limited(
                        "Bound parameter '%s' rendering literal NULL in a SQL "
                        "expression; comparisons to NULL should not use "
                        "operators outside of 'is' or 'is not'",
                        (bindparam.key,),
                    )
                return self.process(sqltypes.NULLTYPE, **kw)
            value = bindparam.effective_value

        if bindparam.expanding:
            # expanding (IN) parameter: render comma-separated literals
            leep = self._literal_execute_expanding_parameter_literal_binds
            to_update, replacement_expr = leep(
                bindparam,
                value,
                bind_expression_template=bind_expression_template,
            )
            return replacement_expr
        else:
            return self.render_literal_value(value, bindparam.type)
3856
3857 def render_literal_value(
3858 self, value: Any, type_: sqltypes.TypeEngine[Any]
3859 ) -> str:
3860 """Render the value of a bind parameter as a quoted literal.
3861
3862 This is used for statement sections that do not accept bind parameters
3863 on the target driver/database.
3864
3865 This should be implemented by subclasses using the quoting services
3866 of the DBAPI.
3867
3868 """
3869
3870 if value is None and not type_.should_evaluate_none:
3871 # issue #10535 - handle NULL in the compiler without placing
3872 # this onto each type, except for "evaluate None" types
3873 # (e.g. JSON)
3874 return self.process(elements.Null._instance())
3875
3876 processor = type_._cached_literal_processor(self.dialect)
3877 if processor:
3878 try:
3879 return processor(value)
3880 except Exception as e:
3881 raise exc.CompileError(
3882 f"Could not render literal value "
3883 f'"{sql_util._repr_single_value(value)}" '
3884 f"with datatype "
3885 f"{type_}; see parent stack trace for "
3886 "more detail."
3887 ) from e
3888
3889 else:
3890 raise exc.CompileError(
3891 f"No literal value renderer is available for literal value "
3892 f'"{sql_util._repr_single_value(value)}" '
3893 f"with datatype {type_}"
3894 )
3895
3896 def _truncate_bindparam(self, bindparam):
3897 if bindparam in self.bind_names:
3898 return self.bind_names[bindparam]
3899
3900 bind_name = bindparam.key
3901 if isinstance(bind_name, elements._truncated_label):
3902 bind_name = self._truncated_identifier("bindparam", bind_name)
3903
3904 # add to bind_names for translation
3905 self.bind_names[bindparam] = bind_name
3906
3907 return bind_name
3908
3909 def _truncated_identifier(
3910 self, ident_class: str, name: _truncated_label
3911 ) -> str:
3912 if (ident_class, name) in self.truncated_names:
3913 return self.truncated_names[(ident_class, name)]
3914
3915 anonname = name.apply_map(self.anon_map)
3916
3917 if len(anonname) > self.label_length - 6:
3918 counter = self._truncated_counters.get(ident_class, 1)
3919 truncname = (
3920 anonname[0 : max(self.label_length - 6, 0)]
3921 + "_"
3922 + hex(counter)[2:]
3923 )
3924 self._truncated_counters[ident_class] = counter + 1
3925 else:
3926 truncname = anonname
3927 self.truncated_names[(ident_class, name)] = truncname
3928 return truncname
3929
3930 def _anonymize(self, name: str) -> str:
3931 return name % self.anon_map
3932
    def bindparam_string(
        self,
        name: str,
        post_compile: bool = False,
        expanding: bool = False,
        escaped_from: Optional[str] = None,
        bindparam_type: Optional[TypeEngine[Any]] = None,
        accumulate_bind_names: Optional[Set[str]] = None,
        visited_bindparam: Optional[List[str]] = None,
        **kw: Any,
    ) -> str:
        """Render the marker text for a bind parameter name.

        Produces either a ``__[POSTCOMPILE_name]`` placeholder (for
        post-compile / expanding parameters) or the dialect's paramstyle
        template filled in with *name*, applying render casts when the
        dialect requests them.
        """
        # TODO: accumulate_bind_names is passed by crud.py to gather
        # names on a per-value basis, visited_bindparam is passed by
        # visit_insert() to collect all parameters in the statement.
        # see if this gathering can be simplified somehow
        if accumulate_bind_names is not None:
            accumulate_bind_names.add(name)
        if visited_bindparam is not None:
            visited_bindparam.append(name)

        if not escaped_from:
            if self._bind_translate_re.search(name):
                # not quite the translate use case as we want to
                # also get a quick boolean if we even found
                # unusual characters in the name
                new_name = self._bind_translate_re.sub(
                    lambda m: self._bind_translate_chars[m.group(0)],
                    name,
                )
                escaped_from = name
                name = new_name

        if escaped_from:
            # record original -> escaped mapping so execution-time
            # parameter dictionaries can be translated
            self.escaped_bind_names = self.escaped_bind_names.union(
                {escaped_from: name}
            )
        if post_compile:
            ret = "__[POSTCOMPILE_%s]" % name
            if expanding:
                # for expanding, bound parameters or literal values will be
                # rendered per item
                return ret

            # otherwise, for non-expanding "literal execute", apply
            # bind casts as determined by the datatype
            if bindparam_type is not None:
                type_impl = bindparam_type._unwrapped_dialect_impl(
                    self.dialect
                )
                if type_impl.render_literal_cast:
                    ret = self.render_bind_cast(bindparam_type, type_impl, ret)
            return ret
        elif self.state is CompilerState.COMPILING:
            # mid-compilation (e.g. numeric paramstyle): render with the
            # intermediary template; the final form is applied later
            ret = self.compilation_bindtemplate % {"name": name}
        else:
            ret = self.bindtemplate % {"name": name}

        if (
            bindparam_type is not None
            and self.dialect._bind_typing_render_casts
        ):
            type_impl = bindparam_type._unwrapped_dialect_impl(self.dialect)
            if type_impl.render_bind_cast:
                ret = self.render_bind_cast(bindparam_type, type_impl, ret)

        return ret
3999
4000 def _dispatch_independent_ctes(self, stmt, kw):
4001 local_kw = kw.copy()
4002 local_kw.pop("cte_opts", None)
4003 for cte, opt in zip(
4004 stmt._independent_ctes, stmt._independent_ctes_opts
4005 ):
4006 cte._compiler_dispatch(self, cte_opts=opt, **local_kw)
4007
4008 def visit_cte(
4009 self,
4010 cte: CTE,
4011 asfrom: bool = False,
4012 ashint: bool = False,
4013 fromhints: Optional[_FromHintsType] = None,
4014 visiting_cte: Optional[CTE] = None,
4015 from_linter: Optional[FromLinter] = None,
4016 cte_opts: selectable._CTEOpts = selectable._CTEOpts(False),
4017 **kwargs: Any,
4018 ) -> Optional[str]:
4019 self_ctes = self._init_cte_state()
4020 assert self_ctes is self.ctes
4021
4022 kwargs["visiting_cte"] = cte
4023
4024 cte_name = cte.name
4025
4026 if isinstance(cte_name, elements._truncated_label):
4027 cte_name = self._truncated_identifier("alias", cte_name)
4028
4029 is_new_cte = True
4030 embedded_in_current_named_cte = False
4031
4032 _reference_cte = cte._get_reference_cte()
4033
4034 nesting = cte.nesting or cte_opts.nesting
4035
4036 # check for CTE already encountered
4037 if _reference_cte in self.level_name_by_cte:
4038 cte_level, _, existing_cte_opts = self.level_name_by_cte[
4039 _reference_cte
4040 ]
4041 assert _ == cte_name
4042
4043 cte_level_name = (cte_level, cte_name)
4044 existing_cte = self.ctes_by_level_name[cte_level_name]
4045
4046 # check if we are receiving it here with a specific
4047 # "nest_here" location; if so, move it to this location
4048
4049 if cte_opts.nesting:
4050 if existing_cte_opts.nesting:
4051 raise exc.CompileError(
4052 "CTE is stated as 'nest_here' in "
4053 "more than one location"
4054 )
4055
4056 old_level_name = (cte_level, cte_name)
4057 cte_level = len(self.stack) if nesting else 1
4058 cte_level_name = new_level_name = (cte_level, cte_name)
4059
4060 del self.ctes_by_level_name[old_level_name]
4061 self.ctes_by_level_name[new_level_name] = existing_cte
4062 self.level_name_by_cte[_reference_cte] = new_level_name + (
4063 cte_opts,
4064 )
4065
4066 else:
4067 cte_level = len(self.stack) if nesting else 1
4068 cte_level_name = (cte_level, cte_name)
4069
4070 if cte_level_name in self.ctes_by_level_name:
4071 existing_cte = self.ctes_by_level_name[cte_level_name]
4072 else:
4073 existing_cte = None
4074
4075 if existing_cte is not None:
4076 embedded_in_current_named_cte = visiting_cte is existing_cte
4077
4078 # we've generated a same-named CTE that we are enclosed in,
4079 # or this is the same CTE. just return the name.
4080 if cte is existing_cte._restates or cte is existing_cte:
4081 is_new_cte = False
4082 elif existing_cte is cte._restates:
4083 # we've generated a same-named CTE that is
4084 # enclosed in us - we take precedence, so
4085 # discard the text for the "inner".
4086 del self_ctes[existing_cte]
4087
4088 existing_cte_reference_cte = existing_cte._get_reference_cte()
4089
4090 assert existing_cte_reference_cte is _reference_cte
4091 assert existing_cte_reference_cte is existing_cte
4092
4093 del self.level_name_by_cte[existing_cte_reference_cte]
4094 else:
4095 if (
4096 # if the two CTEs have the same hash, which we expect
4097 # here means that one/both is an annotated of the other
4098 (hash(cte) == hash(existing_cte))
4099 # or...
4100 or (
4101 (
4102 # if they are clones, i.e. they came from the ORM
4103 # or some other visit method
4104 cte._is_clone_of is not None
4105 or existing_cte._is_clone_of is not None
4106 )
4107 # and are deep-copy identical
4108 and cte.compare(existing_cte)
4109 )
4110 ):
4111 # then consider these two CTEs the same
4112 is_new_cte = False
4113 else:
4114 # otherwise these are two CTEs that either will render
4115 # differently, or were indicated separately by the user,
4116 # with the same name
4117 raise exc.CompileError(
4118 "Multiple, unrelated CTEs found with "
4119 "the same name: %r" % cte_name
4120 )
4121
4122 if not asfrom and not is_new_cte:
4123 return None
4124
4125 if cte._cte_alias is not None:
4126 pre_alias_cte = cte._cte_alias
4127 cte_pre_alias_name = cte._cte_alias.name
4128 if isinstance(cte_pre_alias_name, elements._truncated_label):
4129 cte_pre_alias_name = self._truncated_identifier(
4130 "alias", cte_pre_alias_name
4131 )
4132 else:
4133 pre_alias_cte = cte
4134 cte_pre_alias_name = None
4135
4136 if is_new_cte:
4137 self.ctes_by_level_name[cte_level_name] = cte
4138 self.level_name_by_cte[_reference_cte] = cte_level_name + (
4139 cte_opts,
4140 )
4141
4142 if pre_alias_cte not in self.ctes:
4143 self.visit_cte(pre_alias_cte, **kwargs)
4144
4145 if not cte_pre_alias_name and cte not in self_ctes:
4146 if cte.recursive:
4147 self.ctes_recursive = True
4148 text = self.preparer.format_alias(cte, cte_name)
4149 if cte.recursive or cte.element.name_cte_columns:
4150 col_source = cte.element
4151
4152 # TODO: can we get at the .columns_plus_names collection
4153 # that is already (or will be?) generated for the SELECT
4154 # rather than calling twice?
4155 recur_cols = [
4156 # TODO: proxy_name is not technically safe,
4157 # see test_cte->
4158 # test_with_recursive_no_name_currently_buggy. not
4159 # clear what should be done with such a case
4160 fallback_label_name or proxy_name
4161 for (
4162 _,
4163 proxy_name,
4164 fallback_label_name,
4165 c,
4166 repeated,
4167 ) in (col_source._generate_columns_plus_names(True))
4168 if not repeated
4169 ]
4170
4171 text += "(%s)" % (
4172 ", ".join(
4173 self.preparer.format_label_name(
4174 ident, anon_map=self.anon_map
4175 )
4176 for ident in recur_cols
4177 )
4178 )
4179
4180 assert kwargs.get("subquery", False) is False
4181
4182 if not self.stack:
4183 # toplevel, this is a stringify of the
4184 # cte directly. just compile the inner
4185 # the way alias() does.
4186 return cte.element._compiler_dispatch(
4187 self, asfrom=asfrom, **kwargs
4188 )
4189 else:
4190 prefixes = self._generate_prefixes(
4191 cte, cte._prefixes, **kwargs
4192 )
4193 inner = cte.element._compiler_dispatch(
4194 self, asfrom=True, **kwargs
4195 )
4196
4197 text += " AS %s\n(%s)" % (prefixes, inner)
4198
4199 if cte._suffixes:
4200 text += " " + self._generate_prefixes(
4201 cte, cte._suffixes, **kwargs
4202 )
4203
4204 self_ctes[cte] = text
4205
4206 if asfrom:
4207 if from_linter:
4208 from_linter.froms[cte._de_clone()] = cte_name
4209
4210 if not is_new_cte and embedded_in_current_named_cte:
4211 return self.preparer.format_alias(cte, cte_name)
4212
4213 if cte_pre_alias_name:
4214 text = self.preparer.format_alias(cte, cte_pre_alias_name)
4215 if self.preparer._requires_quotes(cte_name):
4216 cte_name = self.preparer.quote(cte_name)
4217 text += self.get_render_as_alias_suffix(cte_name)
4218 return text # type: ignore[no-any-return]
4219 else:
4220 return self.preparer.format_alias(cte, cte_name)
4221
4222 return None
4223
4224 def visit_table_valued_alias(self, element, **kw):
4225 if element.joins_implicitly:
4226 kw["from_linter"] = None
4227 if element._is_lateral:
4228 return self.visit_lateral(element, **kw)
4229 else:
4230 return self.visit_alias(element, **kw)
4231
    def visit_table_valued_column(self, element, **kw):
        # a column of a table-valued expression renders the same way as
        # an ordinary column
        return self.visit_column(element, **kw)
4234
    def visit_alias(
        self,
        alias,
        asfrom=False,
        ashint=False,
        iscrud=False,
        fromhints=None,
        subquery=False,
        lateral=False,
        enclosing_alias=None,
        from_linter=None,
        **kwargs,
    ):
        """Render an alias construct.

        Depending on flags, renders the alias name alone (``ashint``),
        the full ``<element> AS <name>`` form (``asfrom``), or only the
        inner element with no aliasing applied (neither flag set).
        """
        if lateral:
            if "enclosing_lateral" not in kwargs:
                # if lateral is set and enclosing_lateral is not
                # present, we assume we are being called directly
                # from visit_lateral() and we need to set enclosing_lateral.
                assert alias._is_lateral
                kwargs["enclosing_lateral"] = alias

            # for lateral objects, we track a second from_linter that is...
            # lateral! to the level above us.
            if (
                from_linter
                and "lateral_from_linter" not in kwargs
                and "enclosing_lateral" in kwargs
            ):
                kwargs["lateral_from_linter"] = from_linter

        if enclosing_alias is not None and enclosing_alias.element is alias:
            # we are rendering the inner element of the alias currently
            # being compiled; dispatch to it directly rather than
            # aliasing again
            inner = alias.element._compiler_dispatch(
                self,
                asfrom=asfrom,
                ashint=ashint,
                iscrud=iscrud,
                fromhints=fromhints,
                lateral=lateral,
                enclosing_alias=alias,
                **kwargs,
            )
            if subquery and (asfrom or lateral):
                inner = "(%s)" % (inner,)
            return inner
        else:
            kwargs["enclosing_alias"] = alias

        if asfrom or ashint:
            # resolve deferred (possibly truncated) names into a
            # dialect-length-safe identifier
            if isinstance(alias.name, elements._truncated_label):
                alias_name = self._truncated_identifier("alias", alias.name)
            else:
                alias_name = alias.name

            if ashint:
                return self.preparer.format_alias(alias, alias_name)
            elif asfrom:
                if from_linter:
                    from_linter.froms[alias._de_clone()] = alias_name

                inner = alias.element._compiler_dispatch(
                    self, asfrom=True, lateral=lateral, **kwargs
                )
                if subquery:
                    inner = "(%s)" % (inner,)

                ret = inner + self.get_render_as_alias_suffix(
                    self.preparer.format_alias(alias, alias_name)
                )

                if alias._supports_derived_columns and alias._render_derived:
                    # derived column list, e.g. "name(a, b)" with
                    # optional per-column types
                    ret += "(%s)" % (
                        ", ".join(
                            "%s%s"
                            % (
                                self.preparer.quote(col.name),
                                (
                                    " %s"
                                    % self.dialect.type_compiler_instance.process(
                                        col.type, **kwargs
                                    )
                                    if alias._render_derived_w_types
                                    else ""
                                ),
                            )
                            for col in alias.c
                        )
                    )

                if fromhints and alias in fromhints:
                    ret = self.format_from_hint_text(
                        ret, alias, fromhints[alias], iscrud
                    )

                return ret
        else:
            # note we cancel the "subquery" flag here as well
            return alias.element._compiler_dispatch(
                self, lateral=lateral, **kwargs
            )
4334
4335 def visit_subquery(self, subquery, **kw):
4336 kw["subquery"] = True
4337 return self.visit_alias(subquery, **kw)
4338
4339 def visit_lateral(self, lateral_, **kw):
4340 kw["lateral"] = True
4341 return "LATERAL %s" % self.visit_alias(lateral_, **kw)
4342
4343 def visit_tablesample(self, tablesample, asfrom=False, **kw):
4344 text = "%s TABLESAMPLE %s" % (
4345 self.visit_alias(tablesample, asfrom=True, **kw),
4346 tablesample._get_method()._compiler_dispatch(self, **kw),
4347 )
4348
4349 if tablesample.seed is not None:
4350 text += " REPEATABLE (%s)" % (
4351 tablesample.seed._compiler_dispatch(self, **kw)
4352 )
4353
4354 return text
4355
4356 def _render_values(self, element, **kw):
4357 kw.setdefault("literal_binds", element.literal_binds)
4358 tuples = ", ".join(
4359 self.process(
4360 elements.Tuple(
4361 types=element._column_types, *elem
4362 ).self_group(),
4363 **kw,
4364 )
4365 for chunk in element._data
4366 for elem in chunk
4367 )
4368 return f"VALUES {tuples}"
4369
    def visit_values(
        self, element, asfrom=False, from_linter=None, visiting_cte=None, **kw
    ):
        """Render a ``VALUES`` construct, optionally named and/or
        LATERAL when used as a FROM element."""

        if element._independent_ctes:
            # CTEs attached to this element render separately, ahead of
            # the VALUES body itself
            self._dispatch_independent_ctes(element, kw)

        v = self._render_values(element, **kw)

        if element._unnamed:
            name = None
        elif isinstance(element.name, elements._truncated_label):
            # resolve deferred names into a dialect-length-safe
            # identifier
            name = self._truncated_identifier("values", element.name)
        else:
            name = element.name

        if element._is_lateral:
            lateral = "LATERAL "
        else:
            lateral = ""

        if asfrom:
            if from_linter:
                from_linter.froms[element._de_clone()] = (
                    name if name is not None else "(unnamed VALUES element)"
                )

            if visiting_cte is not None and visiting_cte.element is element:
                if element._is_lateral:
                    raise exc.CompileError(
                        "Can't use a LATERAL VALUES expression inside of a CTE"
                    )
            elif name:
                # named form renders as
                # "(VALUES ...) AS <name> (<col>, <col>, ...)"
                kw["include_table"] = False
                v = "%s(%s)%s (%s)" % (
                    lateral,
                    v,
                    self.get_render_as_alias_suffix(self.preparer.quote(name)),
                    (
                        ", ".join(
                            c._compiler_dispatch(self, **kw)
                            for c in element.columns
                        )
                    ),
                )
            else:
                v = "%s(%s)" % (lateral, v)
        return v
4418
4419 def visit_scalar_values(self, element, **kw):
4420 return f"({self._render_values(element, **kw)})"
4421
4422 def get_render_as_alias_suffix(self, alias_name_text):
4423 return " AS " + alias_name_text
4424
    def _add_to_result_map(
        self,
        keyname: Optional[str],
        name: str,
        objects: Tuple[Any, ...],
        type_: TypeEngine[Any],
    ) -> None:
        """Record a rendered column in ``self._result_columns``.

        :param keyname: key by which the column may be targeted;
         ``None`` or ``"*"`` marks an ad-hoc textual column set and
         disables positional/ordered matching.
        :param name: rendered name / label of the column.
        :param objects: expression objects that target this column;
         must be non-empty.
        :param type_: type object for the column.
        """

        # note objects must be non-empty for cursor.py to handle the
        # collection properly
        assert objects

        if keyname is None or keyname == "*":
            self._ordered_columns = False
            self._ad_hoc_textual = True
        if type_._is_tuple_type:
            raise exc.CompileError(
                "Most backends don't support SELECTing "
                "from a tuple() object. If this is an ORM query, "
                "consider using the Bundle object."
            )
        self._result_columns.append(
            ResultColumnsEntry(keyname, name, objects, type_)
        )
4449
4450 def _label_returning_column(
4451 self, stmt, column, populate_result_map, column_clause_args=None, **kw
4452 ):
4453 """Render a column with necessary labels inside of a RETURNING clause.
4454
4455 This method is provided for individual dialects in place of calling
4456 the _label_select_column method directly, so that the two use cases
4457 of RETURNING vs. SELECT can be disambiguated going forward.
4458
4459 .. versionadded:: 1.4.21
4460
4461 """
4462 return self._label_select_column(
4463 None,
4464 column,
4465 populate_result_map,
4466 False,
4467 {} if column_clause_args is None else column_clause_args,
4468 **kw,
4469 )
4470
    def _label_select_column(
        self,
        select,
        column,
        populate_result_map,
        asfrom,
        column_clause_args,
        name=None,
        proxy_name=None,
        fallback_label_name=None,
        within_columns_clause=True,
        column_is_repeated=False,
        need_column_expressions=False,
        include_table=True,
    ):
        """produce labeled columns present in a select().

        Renders ``column`` (or its type-level replacement expression)
        with an ``AS <label>`` as needed, optionally recording it in the
        result map.

        :param select: the enclosing SELECT, or ``None`` when called for
         RETURNING (see :meth:`._label_returning_column`).
        :param column: the column expression to render.
        :param populate_result_map: if True, record the column via
         ``self._add_to_result_map``.
        :param asfrom: True if the enclosing statement renders as a
         subquery / FROM element; favors labeling.
        :param column_clause_args: dict of keyword arguments for the
         column's compilation; mutated in place with result-map and
         labeling flags before dispatch.
        :param name: explicit label name determined by the caller, if
         any; requires ``proxy_name`` as well.
        :param fallback_label_name: label name to use when a label must
         be rendered and no explicit ``name`` was given.
        :param within_columns_clause: must be True; asserted below.
        :param column_is_repeated: True if this column occurs more than
         once in the columns clause; its result-map targets are then
         suppressed.
        :param need_column_expressions: if True, apply type-level
         column_expression() wrapping even when the result map is not
         being populated.
        :param include_table: passed through to column compilation.
        """
        impl = column.type.dialect_impl(self.dialect)

        if impl._has_column_expression and (
            need_column_expressions or populate_result_map
        ):
            col_expr = impl.column_expression(column)
        else:
            col_expr = column

        if populate_result_map:
            # pass an "add_to_result_map" callable into the compilation
            # of embedded columns. this collects information about the
            # column as it will be fetched in the result and is coordinated
            # with cursor.description when the query is executed.
            add_to_result_map = self._add_to_result_map

            # if the SELECT statement told us this column is a repeat,
            # wrap the callable with one that prevents the addition of the
            # targets
            if column_is_repeated:
                _add_to_result_map = add_to_result_map

                def add_to_result_map(keyname, name, objects, type_):
                    _add_to_result_map(keyname, name, (keyname,), type_)

            # if we redefined col_expr for type expressions, wrap the
            # callable with one that adds the original column to the targets
            elif col_expr is not column:
                _add_to_result_map = add_to_result_map

                def add_to_result_map(keyname, name, objects, type_):
                    _add_to_result_map(
                        keyname, name, (column,) + objects, type_
                    )

        else:
            add_to_result_map = None

        # this method is used by some of the dialects for RETURNING,
        # which has different inputs. _label_returning_column was added
        # as the better target for this now however for 1.4 we will keep
        # _label_select_column directly compatible with this use case.
        # these assertions right now set up the current expected inputs
        assert within_columns_clause, (
            "_label_select_column is only relevant within "
            "the columns clause of a SELECT or RETURNING"
        )
        if isinstance(column, elements.Label):
            if col_expr is not column:
                result_expr = _CompileLabel(
                    col_expr, column.name, alt_names=(column.element,)
                )
            else:
                result_expr = col_expr

        elif name:
            # here, _columns_plus_names has determined there's an explicit
            # label name we need to use. this is the default for
            # tablenames_plus_columnnames as well as when columns are being
            # deduplicated on name

            assert (
                proxy_name is not None
            ), "proxy_name is required if 'name' is passed"

            result_expr = _CompileLabel(
                col_expr,
                name,
                alt_names=(
                    proxy_name,
                    # this is a hack to allow legacy result column lookups
                    # to work as they did before; this goes away in 2.0.
                    # TODO: this only seems to be tested indirectly
                    # via test/orm/test_deprecations.py. should be a
                    # resultset test for this
                    column._tq_label,
                ),
            )
        else:
            # determine here whether this column should be rendered in
            # a labelled context or not, as we were given no required label
            # name from the caller. Here we apply heuristics based on the kind
            # of SQL expression involved.

            if col_expr is not column:
                # type-specific expression wrapping the given column,
                # so we render a label
                render_with_label = True
            elif isinstance(column, elements.ColumnClause):
                # table-bound column, we render its name as a label if we are
                # inside of a subquery only
                render_with_label = (
                    asfrom
                    and not column.is_literal
                    and column.table is not None
                )
            elif isinstance(column, elements.TextClause):
                render_with_label = False
            elif isinstance(column, elements.UnaryExpression):
                # unary expression. notes added as of #12681
                #
                # By convention, the visit_unary() method
                # itself does not add an entry to the result map, and relies
                # upon either the inner expression creating a result map
                # entry, or if not, by creating a label here that produces
                # the result map entry. Where that happens is based on whether
                # or not the element immediately inside the unary is a
                # NamedColumn subclass or not.
                #
                # Now, this also impacts how the SELECT is written; if
                # we decide to generate a label here, we get the usual
                # "~(x+y) AS anon_1" thing in the columns clause. If we
                # don't, we don't get an AS at all, we get like
                # "~table.column".
                #
                # But here is the important thing as of modernish (like 1.4)
                # versions of SQLAlchemy - **whether or not the AS <label>
                # is present in the statement is not actually important**.
                # We target result columns **positionally** for a fully
                # compiled ``Select()`` object; before 1.4 we needed those
                # labels to match in cursor.description etc etc but now it
                # really doesn't matter.
                # So really, we could set render_with_label True in all cases.
                # Or we could just have visit_unary() populate the result map
                # in all cases.
                #
                # What we're doing here is strictly trying to not rock the
                # boat too much with when we do/don't render "AS label";
                # labels being present helps in the edge cases that we
                # "fall back" to named cursor.description matching, labels
                # not being present for columns keeps us from having awkward
                # phrases like "SELECT DISTINCT table.x AS x".
                render_with_label = (
                    (
                        # exception case to detect if we render "not boolean"
                        # as "not <col>" for native boolean or "<col> = 1"
                        # for non-native boolean. this is controlled by
                        # visit_is_<true|false>_unary_operator
                        column.operator
                        in (operators.is_false, operators.is_true)
                        and not self.dialect.supports_native_boolean
                    )
                    or column._wraps_unnamed_column()
                    or asfrom
                )
            elif (
                # general class of expressions that don't have a SQL-column
                # addressible name. includes scalar selects, bind parameters,
                # SQL functions, others
                not isinstance(column, elements.NamedColumn)
                # deeper check that indicates there's no natural "name" to
                # this element, which accommodates for custom SQL constructs
                # that might have a ".name" attribute (but aren't SQL
                # functions) but are not implementing this more recently added
                # base class. in theory the "NamedColumn" check should be
                # enough, however here we seek to maintain legacy behaviors
                # as well.
                and column._non_anon_label is None
            ):
                render_with_label = True
            else:
                render_with_label = False

            if render_with_label:
                if not fallback_label_name:
                    # used by the RETURNING case right now. we generate it
                    # here as 3rd party dialects may be referring to
                    # _label_select_column method directly instead of the
                    # just-added _label_returning_column method
                    assert not column_is_repeated
                    fallback_label_name = column._anon_name_label

                fallback_label_name = (
                    elements._truncated_label(fallback_label_name)
                    if not isinstance(
                        fallback_label_name, elements._truncated_label
                    )
                    else fallback_label_name
                )

                result_expr = _CompileLabel(
                    col_expr, fallback_label_name, alt_names=(proxy_name,)
                )
            else:
                result_expr = col_expr

        column_clause_args.update(
            within_columns_clause=within_columns_clause,
            add_to_result_map=add_to_result_map,
            include_table=include_table,
        )
        return result_expr._compiler_dispatch(self, **column_clause_args)
4679
4680 def format_from_hint_text(self, sqltext, table, hint, iscrud):
4681 hinttext = self.get_from_hint_text(table, hint)
4682 if hinttext:
4683 sqltext += " " + hinttext
4684 return sqltext
4685
    def get_select_hint_text(self, byfroms):
        # base implementation renders no SELECT-level hint text;
        # dialects may override
        return None
4688
    def get_from_hint_text(
        self, table: FromClause, text: Optional[str]
    ) -> Optional[str]:
        # base implementation renders no per-FROM hint text; dialects
        # may override
        return None
4693
    def get_crud_hint_text(self, table, text):
        # base implementation renders no hint text for INSERT / UPDATE /
        # DELETE; dialects may override
        return None
4696
    def get_statement_hint_text(self, hint_texts):
        # statement-level hints are rendered space-separated at the end
        # of the statement
        return " ".join(hint_texts)
4699
    # default entry used for ``self.stack`` when the stack is empty,
    # i.e. at the topmost statement level
    _default_stack_entry: _CompilerStackEntry

    if not typing.TYPE_CHECKING:
        # the runtime value is an immutable dict; the bare annotation
        # above serves the type checker only
        _default_stack_entry = util.immutabledict(
            [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
        )
4706
4707 def _display_froms_for_select(
4708 self, select_stmt, asfrom, lateral=False, **kw
4709 ):
4710 # utility method to help external dialects
4711 # get the correct from list for a select.
4712 # specifically the oracle dialect needs this feature
4713 # right now.
4714 toplevel = not self.stack
4715 entry = self._default_stack_entry if toplevel else self.stack[-1]
4716
4717 compile_state = select_stmt._compile_state_factory(select_stmt, self)
4718
4719 correlate_froms = entry["correlate_froms"]
4720 asfrom_froms = entry["asfrom_froms"]
4721
4722 if asfrom and not lateral:
4723 froms = compile_state._get_display_froms(
4724 explicit_correlate_froms=correlate_froms.difference(
4725 asfrom_froms
4726 ),
4727 implicit_correlate_froms=(),
4728 )
4729 else:
4730 froms = compile_state._get_display_froms(
4731 explicit_correlate_froms=correlate_froms,
4732 implicit_correlate_froms=asfrom_froms,
4733 )
4734 return froms
4735
    # hook attribute consulted in visit_select(); dialects may assign a
    # callable here to rewrite the SELECT structurally before rendering
    translate_select_structure: Any = None
    """if not ``None``, should be a callable which accepts ``(select_stmt,
    **kw)`` and returns a select object. this is used for structural changes
    mostly to accommodate for LIMIT/OFFSET schemes

    """
4742
    def visit_select(
        self,
        select_stmt,
        asfrom=False,
        insert_into=False,
        fromhints=None,
        compound_index=None,
        select_wraps_for=None,
        lateral=False,
        from_linter=None,
        **kwargs,
    ):
        """Render a SELECT statement.

        Runs the compile-state factory (which for ORM use converts the
        given statement into a Core SELECT), optionally applies the
        dialect's ``translate_select_structure`` hook, renders the
        columns clause and body, and prepends collected CTE text at the
        appropriate level.
        """
        assert select_wraps_for is None, (
            "SQLAlchemy 1.4 requires use of "
            "the translate_select_structure hook for structural "
            "translations of SELECT objects"
        )

        # initial setup of SELECT. the compile_state_factory may now
        # be creating a totally different SELECT from the one that was
        # passed in. for ORM use this will convert from an ORM-state
        # SELECT to a regular "Core" SELECT. other composed operations
        # such as computation of joins will be performed.

        kwargs["within_columns_clause"] = False

        compile_state = select_stmt._compile_state_factory(
            select_stmt, self, **kwargs
        )
        kwargs["ambiguous_table_name_map"] = (
            compile_state._ambiguous_table_name_map
        )

        select_stmt = compile_state.statement

        toplevel = not self.stack

        if toplevel and not self.compile_state:
            self.compile_state = compile_state

        is_embedded_select = compound_index is not None or insert_into

        # translate step for Oracle, SQL Server which often need to
        # restructure the SELECT to allow for LIMIT/OFFSET and possibly
        # other conditions
        if self.translate_select_structure:
            new_select_stmt = self.translate_select_structure(
                select_stmt, asfrom=asfrom, **kwargs
            )

            # if SELECT was restructured, maintain a link to the originals
            # and assemble a new compile state
            if new_select_stmt is not select_stmt:
                compile_state_wraps_for = compile_state
                select_wraps_for = select_stmt
                select_stmt = new_select_stmt

                compile_state = select_stmt._compile_state_factory(
                    select_stmt, self, **kwargs
                )
                select_stmt = compile_state.statement

        entry = self._default_stack_entry if toplevel else self.stack[-1]

        populate_result_map = need_column_expressions = (
            toplevel
            or entry.get("need_result_map_for_compound", False)
            or entry.get("need_result_map_for_nested", False)
        )

        # indicates there is a CompoundSelect in play and we are not the
        # first select
        if compound_index:
            populate_result_map = False

        # this was first proposed as part of #3372; however, it is not
        # reached in current tests and could possibly be an assertion
        # instead.
        if not populate_result_map and "add_to_result_map" in kwargs:
            del kwargs["add_to_result_map"]

        froms = self._setup_select_stack(
            select_stmt, compile_state, entry, asfrom, lateral, compound_index
        )

        column_clause_args = kwargs.copy()
        column_clause_args.update(
            {"within_label_clause": False, "within_columns_clause": False}
        )

        text = "SELECT "  # we're off to a good start !

        if select_stmt._post_select_clause is not None:
            psc = self.process(select_stmt._post_select_clause, **kwargs)
            if psc is not None:
                text += psc + " "

        if select_stmt._hints:
            hint_text, byfrom = self._setup_select_hints(select_stmt)
            if hint_text:
                text += hint_text + " "
        else:
            byfrom = None

        if select_stmt._independent_ctes:
            self._dispatch_independent_ctes(select_stmt, kwargs)

        if select_stmt._prefixes:
            text += self._generate_prefixes(
                select_stmt, select_stmt._prefixes, **kwargs
            )

        text += self.get_select_precolumns(select_stmt, **kwargs)

        if select_stmt._pre_columns_clause is not None:
            pcc = self.process(select_stmt._pre_columns_clause, **kwargs)
            if pcc is not None:
                text += pcc + " "

        # the actual list of columns to print in the SELECT column list.
        inner_columns = [
            c
            for c in [
                self._label_select_column(
                    select_stmt,
                    column,
                    populate_result_map,
                    asfrom,
                    column_clause_args,
                    name=name,
                    proxy_name=proxy_name,
                    fallback_label_name=fallback_label_name,
                    column_is_repeated=repeated,
                    need_column_expressions=need_column_expressions,
                )
                for (
                    name,
                    proxy_name,
                    fallback_label_name,
                    column,
                    repeated,
                ) in compile_state.columns_plus_names
            ]
            if c is not None
        ]

        if populate_result_map and select_wraps_for is not None:
            # if this select was generated from translate_select,
            # rewrite the targeted columns in the result map

            translate = dict(
                zip(
                    [
                        name
                        for (
                            key,
                            proxy_name,
                            fallback_label_name,
                            name,
                            repeated,
                        ) in compile_state.columns_plus_names
                    ],
                    [
                        name
                        for (
                            key,
                            proxy_name,
                            fallback_label_name,
                            name,
                            repeated,
                        ) in compile_state_wraps_for.columns_plus_names
                    ],
                )
            )

            self._result_columns = [
                ResultColumnsEntry(
                    key, name, tuple(translate.get(o, o) for o in obj), type_
                )
                for key, name, obj, type_ in self._result_columns
            ]

        text = self._compose_select_body(
            text,
            select_stmt,
            compile_state,
            inner_columns,
            froms,
            byfrom,
            toplevel,
            kwargs,
        )

        if select_stmt._post_body_clause is not None:
            pbc = self.process(select_stmt._post_body_clause, **kwargs)
            if pbc:
                text += " " + pbc

        if select_stmt._statement_hints:
            per_dialect = [
                ht
                for (dialect_name, ht) in select_stmt._statement_hints
                if dialect_name in ("*", self.dialect.name)
            ]
            if per_dialect:
                text += " " + self.get_statement_hint_text(per_dialect)

        # In compound query, CTEs are shared at the compound level
        if self.ctes and (not is_embedded_select or toplevel):
            nesting_level = len(self.stack) if not toplevel else None
            text = self._render_cte_clause(nesting_level=nesting_level) + text

        if select_stmt._suffixes:
            text += " " + self._generate_prefixes(
                select_stmt, select_stmt._suffixes, **kwargs
            )

        self.stack.pop(-1)

        return text
4963
4964 def _setup_select_hints(
4965 self, select: Select[Unpack[TupleAny]]
4966 ) -> Tuple[str, _FromHintsType]:
4967 byfrom = {
4968 from_: hinttext
4969 % {"name": from_._compiler_dispatch(self, ashint=True)}
4970 for (from_, dialect), hinttext in select._hints.items()
4971 if dialect in ("*", self.dialect.name)
4972 }
4973 hint_text = self.get_select_hint_text(byfrom)
4974 return hint_text, byfrom
4975
    def _setup_select_stack(
        self, select, compile_state, entry, asfrom, lateral, compound_index
    ):
        """Compute the display FROM list for a SELECT and push a new
        compiler stack entry for it.

        Also validates that each member of a CompoundSelect has the same
        number of columns as the first member, raising CompileError
        otherwise.  Returns the FROM list.
        """
        correlate_froms = entry["correlate_froms"]
        asfrom_froms = entry["asfrom_froms"]

        if compound_index == 0:
            # first SELECT of a compound: remember it so later members
            # can be validated against its column count
            entry["select_0"] = select
        elif compound_index:
            select_0 = entry["select_0"]
            numcols = len(select_0._all_selected_columns)

            if len(compile_state.columns_plus_names) != numcols:
                raise exc.CompileError(
                    "All selectables passed to "
                    "CompoundSelect must have identical numbers of "
                    "columns; select #%d has %d columns, select "
                    "#%d has %d"
                    % (
                        1,
                        numcols,
                        compound_index + 1,
                        len(select._all_selected_columns),
                    )
                )

        if asfrom and not lateral:
            # rendering as a subquery element: correlate explicitly only
            # against enclosing FROMs that are not our own
            froms = compile_state._get_display_froms(
                explicit_correlate_froms=correlate_froms.difference(
                    asfrom_froms
                ),
                implicit_correlate_froms=(),
            )
        else:
            froms = compile_state._get_display_froms(
                explicit_correlate_froms=correlate_froms,
                implicit_correlate_froms=asfrom_froms,
            )

        new_correlate_froms = set(_from_objects(*froms))
        all_correlate_froms = new_correlate_froms.union(correlate_froms)

        new_entry: _CompilerStackEntry = {
            "asfrom_froms": new_correlate_froms,
            "correlate_froms": all_correlate_froms,
            "selectable": select,
            "compile_state": compile_state,
        }
        self.stack.append(new_entry)

        return froms
5027
    def _compose_select_body(
        self,
        text,
        select,
        compile_state,
        inner_columns,
        froms,
        byfrom,
        toplevel,
        kwargs,
    ):
        """Append the body of a SELECT to ``text``: columns clause,
        FROM, WHERE, GROUP BY, HAVING, ORDER BY, row-limiting and
        FOR UPDATE clauses, returning the assembled string.
        """
        text += ", ".join(inner_columns)

        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
            # a fresh linter collects FROM elements rendered within this
            # statement for cartesian product detection
            from_linter = FromLinter({}, set())
            warn_linting = self.linting & WARN_LINTING
            if toplevel:
                self.from_linter = from_linter
        else:
            from_linter = None
            warn_linting = False

        # adjust the whitespace for no inner columns, part of #9440,
        # so that a no-col SELECT comes out as "SELECT WHERE..." or
        # "SELECT FROM ...".
        # while it would be better to have built the SELECT starting string
        # without trailing whitespace first, then add whitespace only if inner
        # cols were present, this breaks compatibility with various custom
        # compilation schemes that are currently being tested.
        if not inner_columns:
            text = text.rstrip()

        if froms:
            text += " \nFROM "

            if select._hints:
                text += ", ".join(
                    [
                        f._compiler_dispatch(
                            self,
                            asfrom=True,
                            fromhints=byfrom,
                            from_linter=from_linter,
                            **kwargs,
                        )
                        for f in froms
                    ]
                )
            else:
                text += ", ".join(
                    [
                        f._compiler_dispatch(
                            self,
                            asfrom=True,
                            from_linter=from_linter,
                            **kwargs,
                        )
                        for f in froms
                    ]
                )
        else:
            text += self.default_from()

        if select._where_criteria:
            t = self._generate_delimited_and_list(
                select._where_criteria, from_linter=from_linter, **kwargs
            )
            if t:
                text += " \nWHERE " + t

        if warn_linting:
            assert from_linter is not None
            from_linter.warn()

        if select._group_by_clauses:
            text += self.group_by_clause(select, **kwargs)

        if select._having_criteria:
            t = self._generate_delimited_and_list(
                select._having_criteria, **kwargs
            )
            if t:
                text += " \nHAVING " + t

        if select._post_criteria_clause is not None:
            pcc = self.process(select._post_criteria_clause, **kwargs)
            if pcc is not None:
                text += " \n" + pcc

        if select._order_by_clauses:
            text += self.order_by_clause(select, **kwargs)

        if select._has_row_limiting_clause:
            text += self._row_limit_clause(select, **kwargs)

        if select._for_update_arg is not None:
            text += self.for_update_clause(select, **kwargs)

        return text
5127
5128 def _generate_prefixes(self, stmt, prefixes, **kw):
5129 clause = " ".join(
5130 prefix._compiler_dispatch(self, **kw)
5131 for prefix, dialect_name in prefixes
5132 if dialect_name in (None, "*") or dialect_name == self.dialect.name
5133 )
5134 if clause:
5135 clause += " "
5136 return clause
5137
5138 def _render_cte_clause(
5139 self,
5140 nesting_level=None,
5141 include_following_stack=False,
5142 ):
5143 """
5144 include_following_stack
5145 Also render the nesting CTEs on the next stack. Useful for
5146 SQL structures like UNION or INSERT that can wrap SELECT
5147 statements containing nesting CTEs.
5148 """
5149 if not self.ctes:
5150 return ""
5151
5152 ctes: MutableMapping[CTE, str]
5153
5154 if nesting_level and nesting_level > 1:
5155 ctes = util.OrderedDict()
5156 for cte in list(self.ctes.keys()):
5157 cte_level, cte_name, cte_opts = self.level_name_by_cte[
5158 cte._get_reference_cte()
5159 ]
5160 nesting = cte.nesting or cte_opts.nesting
5161 is_rendered_level = cte_level == nesting_level or (
5162 include_following_stack and cte_level == nesting_level + 1
5163 )
5164 if not (nesting and is_rendered_level):
5165 continue
5166
5167 ctes[cte] = self.ctes[cte]
5168
5169 else:
5170 ctes = self.ctes
5171
5172 if not ctes:
5173 return ""
5174 ctes_recursive = any([cte.recursive for cte in ctes])
5175
5176 cte_text = self.get_cte_preamble(ctes_recursive) + " "
5177 cte_text += ", \n".join([txt for txt in ctes.values()])
5178 cte_text += "\n "
5179
5180 if nesting_level and nesting_level > 1:
5181 for cte in list(ctes.keys()):
5182 cte_level, cte_name, cte_opts = self.level_name_by_cte[
5183 cte._get_reference_cte()
5184 ]
5185 del self.ctes[cte]
5186 del self.ctes_by_level_name[(cte_level, cte_name)]
5187 del self.level_name_by_cte[cte._get_reference_cte()]
5188
5189 return cte_text
5190
5191 def get_cte_preamble(self, recursive):
5192 if recursive:
5193 return "WITH RECURSIVE"
5194 else:
5195 return "WITH"
5196
5197 def get_select_precolumns(self, select: Select[Any], **kw: Any) -> str:
5198 """Called when building a ``SELECT`` statement, position is just
5199 before column list.
5200
5201 """
5202 if select._distinct_on:
5203 util.warn_deprecated(
5204 "DISTINCT ON is currently supported only by the PostgreSQL "
5205 "dialect. Use of DISTINCT ON for other backends is currently "
5206 "silently ignored, however this usage is deprecated, and will "
5207 "raise CompileError in a future release for all backends "
5208 "that do not support this syntax.",
5209 version="1.4",
5210 )
5211 return "DISTINCT " if select._distinct else ""
5212
5213 def group_by_clause(self, select, **kw):
5214 """allow dialects to customize how GROUP BY is rendered."""
5215
5216 group_by = self._generate_delimited_list(
5217 select._group_by_clauses, OPERATORS[operators.comma_op], **kw
5218 )
5219 if group_by:
5220 return " GROUP BY " + group_by
5221 else:
5222 return ""
5223
5224 def order_by_clause(self, select, **kw):
5225 """allow dialects to customize how ORDER BY is rendered."""
5226
5227 order_by = self._generate_delimited_list(
5228 select._order_by_clauses, OPERATORS[operators.comma_op], **kw
5229 )
5230
5231 if order_by:
5232 return " ORDER BY " + order_by
5233 else:
5234 return ""
5235
5236 def for_update_clause(self, select, **kw):
5237 return " FOR UPDATE"
5238
5239 def returning_clause(
5240 self,
5241 stmt: UpdateBase,
5242 returning_cols: Sequence[_ColumnsClauseElement],
5243 *,
5244 populate_result_map: bool,
5245 **kw: Any,
5246 ) -> str:
5247 columns = [
5248 self._label_returning_column(
5249 stmt,
5250 column,
5251 populate_result_map,
5252 fallback_label_name=fallback_label_name,
5253 column_is_repeated=repeated,
5254 name=name,
5255 proxy_name=proxy_name,
5256 **kw,
5257 )
5258 for (
5259 name,
5260 proxy_name,
5261 fallback_label_name,
5262 column,
5263 repeated,
5264 ) in stmt._generate_columns_plus_names(
5265 True, cols=base._select_iterables(returning_cols)
5266 )
5267 ]
5268
5269 return "RETURNING " + ", ".join(columns)
5270
5271 def limit_clause(self, select, **kw):
5272 text = ""
5273 if select._limit_clause is not None:
5274 text += "\n LIMIT " + self.process(select._limit_clause, **kw)
5275 if select._offset_clause is not None:
5276 if select._limit_clause is None:
5277 text += "\n LIMIT -1"
5278 text += " OFFSET " + self.process(select._offset_clause, **kw)
5279 return text
5280
5281 def fetch_clause(
5282 self,
5283 select,
5284 fetch_clause=None,
5285 require_offset=False,
5286 use_literal_execute_for_simple_int=False,
5287 **kw,
5288 ):
5289 if fetch_clause is None:
5290 fetch_clause = select._fetch_clause
5291 fetch_clause_options = select._fetch_clause_options
5292 else:
5293 fetch_clause_options = {"percent": False, "with_ties": False}
5294
5295 text = ""
5296
5297 if select._offset_clause is not None:
5298 offset_clause = select._offset_clause
5299 if (
5300 use_literal_execute_for_simple_int
5301 and select._simple_int_clause(offset_clause)
5302 ):
5303 offset_clause = offset_clause.render_literal_execute()
5304 offset_str = self.process(offset_clause, **kw)
5305 text += "\n OFFSET %s ROWS" % offset_str
5306 elif require_offset:
5307 text += "\n OFFSET 0 ROWS"
5308
5309 if fetch_clause is not None:
5310 if (
5311 use_literal_execute_for_simple_int
5312 and select._simple_int_clause(fetch_clause)
5313 ):
5314 fetch_clause = fetch_clause.render_literal_execute()
5315 text += "\n FETCH FIRST %s%s ROWS %s" % (
5316 self.process(fetch_clause, **kw),
5317 " PERCENT" if fetch_clause_options["percent"] else "",
5318 "WITH TIES" if fetch_clause_options["with_ties"] else "ONLY",
5319 )
5320 return text
5321
5322 def visit_table(
5323 self,
5324 table,
5325 asfrom=False,
5326 iscrud=False,
5327 ashint=False,
5328 fromhints=None,
5329 use_schema=True,
5330 from_linter=None,
5331 ambiguous_table_name_map=None,
5332 enclosing_alias=None,
5333 **kwargs,
5334 ):
5335 if from_linter:
5336 from_linter.froms[table] = table.fullname
5337
5338 if asfrom or ashint:
5339 effective_schema = self.preparer.schema_for_object(table)
5340
5341 if use_schema and effective_schema:
5342 ret = (
5343 self.preparer.quote_schema(effective_schema)
5344 + "."
5345 + self.preparer.quote(table.name)
5346 )
5347 else:
5348 ret = self.preparer.quote(table.name)
5349
5350 if (
5351 (
5352 enclosing_alias is None
5353 or enclosing_alias.element is not table
5354 )
5355 and not effective_schema
5356 and ambiguous_table_name_map
5357 and table.name in ambiguous_table_name_map
5358 ):
5359 anon_name = self._truncated_identifier(
5360 "alias", ambiguous_table_name_map[table.name]
5361 )
5362
5363 ret = ret + self.get_render_as_alias_suffix(
5364 self.preparer.format_alias(None, anon_name)
5365 )
5366
5367 if fromhints and table in fromhints:
5368 ret = self.format_from_hint_text(
5369 ret, table, fromhints[table], iscrud
5370 )
5371 return ret
5372 else:
5373 return ""
5374
5375 def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
5376 if from_linter:
5377 from_linter.edges.update(
5378 itertools.product(
5379 _de_clone(join.left._from_objects),
5380 _de_clone(join.right._from_objects),
5381 )
5382 )
5383
5384 if join.full:
5385 join_type = " FULL OUTER JOIN "
5386 elif join.isouter:
5387 join_type = " LEFT OUTER JOIN "
5388 else:
5389 join_type = " JOIN "
5390 return (
5391 join.left._compiler_dispatch(
5392 self, asfrom=True, from_linter=from_linter, **kwargs
5393 )
5394 + join_type
5395 + join.right._compiler_dispatch(
5396 self, asfrom=True, from_linter=from_linter, **kwargs
5397 )
5398 + " ON "
5399 # TODO: likely need asfrom=True here?
5400 + join.onclause._compiler_dispatch(
5401 self, from_linter=from_linter, **kwargs
5402 )
5403 )
5404
5405 def _setup_crud_hints(self, stmt, table_text):
5406 dialect_hints = {
5407 table: hint_text
5408 for (table, dialect), hint_text in stmt._hints.items()
5409 if dialect in ("*", self.dialect.name)
5410 }
5411 if stmt.table in dialect_hints:
5412 table_text = self.format_from_hint_text(
5413 table_text, stmt.table, dialect_hints[stmt.table], True
5414 )
5415 return dialect_hints, table_text
5416
    # within the realm of "insertmanyvalues sentinel columns",
    # these lookups match different kinds of Column() configurations
    # to specific backend capabilities. they are broken into two
    # lookups, one for autoincrement columns and the other for non
    # autoincrement columns
    _sentinel_col_non_autoinc_lookup = util.immutabledict(
        {
            # client-side default characterization maps to the generic
            # "supported at all" capability bit
            _SentinelDefaultCharacterization.CLIENTSIDE: (
                InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT
            ),
            # a dedicated sentinel default likewise maps to the generic
            # capability bit
            _SentinelDefaultCharacterization.SENTINEL_DEFAULT: (
                InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT
            ),
            # no default at all (non-autoincrement case) also maps to the
            # generic capability bit
            _SentinelDefaultCharacterization.NONE: (
                InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT
            ),
            # server-side IDENTITY maps to the dialect's IDENTITY
            # capability bit
            _SentinelDefaultCharacterization.IDENTITY: (
                InsertmanyvaluesSentinelOpts.IDENTITY
            ),
            # server-side sequence maps to the dialect's SEQUENCE
            # capability bit
            _SentinelDefaultCharacterization.SEQUENCE: (
                InsertmanyvaluesSentinelOpts.SEQUENCE
            ),
        }
    )
    # autoincrement columns share the mapping above, except a column with
    # no default maps to the AUTOINCREMENT capability bit instead
    _sentinel_col_autoinc_lookup = _sentinel_col_non_autoinc_lookup.union(
        {
            _SentinelDefaultCharacterization.NONE: (
                InsertmanyvaluesSentinelOpts.AUTOINCREMENT
            ),
        }
    )
5448
5449 def _get_sentinel_column_for_table(
5450 self, table: Table
5451 ) -> Optional[Sequence[Column[Any]]]:
5452 """given a :class:`.Table`, return a usable sentinel column or
5453 columns for this dialect if any.
5454
5455 Return None if no sentinel columns could be identified, or raise an
5456 error if a column was marked as a sentinel explicitly but isn't
5457 compatible with this dialect.
5458
5459 """
5460
5461 sentinel_opts = self.dialect.insertmanyvalues_implicit_sentinel
5462 sentinel_characteristics = table._sentinel_column_characteristics
5463
5464 sent_cols = sentinel_characteristics.columns
5465
5466 if sent_cols is None:
5467 return None
5468
5469 if sentinel_characteristics.is_autoinc:
5470 bitmask = self._sentinel_col_autoinc_lookup.get(
5471 sentinel_characteristics.default_characterization, 0
5472 )
5473 else:
5474 bitmask = self._sentinel_col_non_autoinc_lookup.get(
5475 sentinel_characteristics.default_characterization, 0
5476 )
5477
5478 if sentinel_opts & bitmask:
5479 return sent_cols
5480
5481 if sentinel_characteristics.is_explicit:
5482 # a column was explicitly marked as insert_sentinel=True,
5483 # however it is not compatible with this dialect. they should
5484 # not indicate this column as a sentinel if they need to include
5485 # this dialect.
5486
5487 # TODO: do we want non-primary key explicit sentinel cols
5488 # that can gracefully degrade for some backends?
5489 # insert_sentinel="degrade" perhaps. not for the initial release.
5490 # I am hoping people are generally not dealing with this sentinel
5491 # business at all.
5492
5493 # if is_explicit is True, there will be only one sentinel column.
5494
5495 raise exc.InvalidRequestError(
5496 f"Column {sent_cols[0]} can't be explicitly "
5497 "marked as a sentinel column when using the "
5498 f"{self.dialect.name} dialect, as the "
5499 "particular type of default generation on this column is "
5500 "not currently compatible with this dialect's specific "
5501 f"INSERT..RETURNING syntax which can receive the "
5502 "server-generated value in "
5503 "a deterministic way. To remove this error, remove "
5504 "insert_sentinel=True from primary key autoincrement "
5505 "columns; these columns are automatically used as "
5506 "sentinels for supported dialects in any case."
5507 )
5508
5509 return None
5510
    def _deliver_insertmanyvalues_batches(
        self,
        statement: str,
        parameters: _DBAPIMultiExecuteParams,
        compiled_parameters: List[_MutableCoreSingleExecuteParams],
        generic_setinputsizes: Optional[_GenericSetInputSizesType],
        batch_size: int,
        sort_by_parameter_order: bool,
        schema_translate_map: Optional[SchemaTranslateMapType],
    ) -> Iterator[_InsertManyValuesBatch]:
        """Split an "insertmanyvalues" execution into deliverable batches.

        Yields :class:`._InsertManyValuesBatch` objects, each carrying a
        statement whose VALUES clause has been expanded for up to
        ``batch_size`` parameter sets, the corresponding parameters, and
        the client-side sentinel values (if any) extracted from each
        compiled parameter set.

        :param statement: compiled INSERT string containing a single
         rendered VALUES group, which is tokenized and re-expanded per
         batch.
        :param parameters: DBAPI-level parameter sets.
        :param compiled_parameters: Core-level parameter dictionaries,
         iterated in lockstep with ``parameters``; used to extract
         sentinel values.
        :param generic_setinputsizes: optional setinputsizes collection,
         expanded to match each batch's length.
        :param batch_size: maximum number of parameter sets per batch;
         may be reduced further when the dialect sets
         ``insertmanyvalues_max_parameters``.
        :param sort_by_parameter_order: whether RETURNING rows are to be
         correlated back to parameter order; may force row-at-a-time
         delivery when no sentinel is available.
        :param schema_translate_map: optional schema translation applied
         to the rendered expressions before token replacement.
        """
        imv = self._insertmanyvalues
        assert imv is not None

        if not imv.sentinel_param_keys:
            _sentinel_from_params = None
        else:
            # callable extracting the sentinel value(s) from a compiled
            # parameter dictionary
            _sentinel_from_params = operator.itemgetter(
                *imv.sentinel_param_keys
            )

        lenparams = len(parameters)
        if imv.is_default_expr and not self.dialect.supports_default_metavalue:
            # backend doesn't support
            # INSERT INTO table (pk_col) VALUES (DEFAULT), (DEFAULT), ...
            # at the moment this is basically SQL Server due to
            # not being able to use DEFAULT for identity column
            # just yield out that many single statements! still
            # faster than a whole connection.execute() call ;)
            #
            # note we still are taking advantage of the fact that we know
            # we are using RETURNING. The generalized approach of fetching
            # cursor.lastrowid etc. still goes through the more heavyweight
            # "ExecutionContext per statement" system as it isn't usable
            # as a generic "RETURNING" approach
            use_row_at_a_time = True
            downgraded = False
        elif not self.dialect.supports_multivalues_insert or (
            sort_by_parameter_order
            and self._result_columns
            and (imv.sentinel_columns is None or imv.includes_upsert_behaviors)
        ):
            # deterministic order was requested and the compiler could
            # not organize sentinel columns for this dialect/statement.
            # use row at a time
            use_row_at_a_time = True
            downgraded = True
        else:
            use_row_at_a_time = False
            downgraded = False

        if use_row_at_a_time:
            # yield one single-parameter-set batch per parameter set,
            # statement unmodified
            for batchnum, (param, compiled_param) in enumerate(
                cast(
                    "Sequence[Tuple[_DBAPISingleExecuteParams, _MutableCoreSingleExecuteParams]]", # noqa: E501
                    zip(parameters, compiled_parameters),
                ),
                1,
            ):
                yield _InsertManyValuesBatch(
                    statement,
                    param,
                    generic_setinputsizes,
                    [param],
                    (
                        [_sentinel_from_params(compiled_param)]
                        if _sentinel_from_params
                        else []
                    ),
                    1,
                    batchnum,
                    lenparams,
                    sort_by_parameter_order,
                    downgraded,
                )
            return

        if schema_translate_map:
            rst = functools.partial(
                self.preparer._render_schema_translates,
                schema_translate_map=schema_translate_map,
            )
        else:
            rst = None

        imv_single_values_expr = imv.single_values_expr
        if rst:
            imv_single_values_expr = rst(imv_single_values_expr)

        # replace the single rendered VALUES group with a token that is
        # re-expanded into the per-batch VALUES list below
        executemany_values = f"({imv_single_values_expr})"
        statement = statement.replace(executemany_values, "__EXECMANY_TOKEN__")

        # Use optional insertmanyvalues_max_parameters
        # to further shrink the batch size so that there are no more than
        # insertmanyvalues_max_parameters params.
        # Currently used by SQL Server, which limits statements to 2100 bound
        # parameters (actually 2099).
        max_params = self.dialect.insertmanyvalues_max_parameters
        if max_params:
            total_num_of_params = len(self.bind_names)
            num_params_per_batch = len(imv.insert_crud_params)
            num_params_outside_of_batch = (
                total_num_of_params - num_params_per_batch
            )
            batch_size = min(
                batch_size,
                (
                    (max_params - num_params_outside_of_batch)
                    // num_params_per_batch
                ),
            )

        batches = cast("List[Sequence[Any]]", list(parameters))
        compiled_batches = cast(
            "List[Sequence[Any]]", list(compiled_parameters)
        )

        processed_setinputsizes: Optional[_GenericSetInputSizesType] = None
        batchnum = 1
        total_batches = lenparams // batch_size + (
            1 if lenparams % batch_size else 0
        )

        insert_crud_params = imv.insert_crud_params
        assert insert_crud_params is not None

        if rst:
            # apply schema translation to each per-column expression
            insert_crud_params = [
                (col, key, rst(expr), st)
                for col, key, expr, st in insert_crud_params
            ]

        escaped_bind_names: Mapping[str, str]
        expand_pos_lower_index = expand_pos_upper_index = 0

        if not self.positional:
            # named-parameter dialects: rewrite each bound parameter name
            # to carry a per-row index suffix
            if self.escaped_bind_names:
                escaped_bind_names = self.escaped_bind_names
            else:
                escaped_bind_names = {}

            all_keys = set(parameters[0])

            def apply_placeholders(keys, formatted):
                # rewrite each bind name in the given clause to carry an
                # __EXECMANY_INDEX__ suffix, replaced with the row index
                # for each batch entry
                for key in keys:
                    key = escaped_bind_names.get(key, key)
                    formatted = formatted.replace(
                        self.bindtemplate % {"name": key},
                        self.bindtemplate
                        % {"name": f"{key}__EXECMANY_INDEX__"},
                    )
                return formatted

            if imv.embed_values_counter:
                imv_values_counter = ", _IMV_VALUES_COUNTER"
            else:
                imv_values_counter = ""
            formatted_values_clause = f"""({', '.join(
                apply_placeholders(bind_keys, formatted)
                for _, _, formatted, bind_keys in insert_crud_params
            )}{imv_values_counter})"""

            keys_to_replace = all_keys.intersection(
                escaped_bind_names.get(key, key)
                for _, _, _, bind_keys in insert_crud_params
                for key in bind_keys
            )
            # parameters outside the VALUES clause are carried through
            # unchanged into each batch
            base_parameters = {
                key: parameters[0][key]
                for key in all_keys.difference(keys_to_replace)
            }
            executemany_values_w_comma = ""
        else:
            formatted_values_clause = ""
            keys_to_replace = set()
            base_parameters = {}

            if imv.embed_values_counter:
                executemany_values_w_comma = (
                    f"({imv_single_values_expr}, _IMV_VALUES_COUNTER), "
                )
            else:
                executemany_values_w_comma = f"({imv_single_values_expr}), "

            all_names_we_will_expand: Set[str] = set()
            for elem in imv.insert_crud_params:
                all_names_we_will_expand.update(elem[3])

            # get the start and end position in a particular list
            # of parameters where we will be doing the "expanding".
            # statements can have params on either side or both sides,
            # given RETURNING and CTEs
            if all_names_we_will_expand:
                positiontup = self.positiontup
                assert positiontup is not None

                all_expand_positions = {
                    idx
                    for idx, name in enumerate(positiontup)
                    if name in all_names_we_will_expand
                }
                expand_pos_lower_index = min(all_expand_positions)
                expand_pos_upper_index = max(all_expand_positions) + 1
                # the expanded positions must be contiguous
                assert (
                    len(all_expand_positions)
                    == expand_pos_upper_index - expand_pos_lower_index
                )

            if self._numeric_binds:
                escaped = re.escape(self._numeric_binds_identifier_char)
                executemany_values_w_comma = re.sub(
                    rf"{escaped}\d+", "%s", executemany_values_w_comma
                )

        while batches:
            # consume up to batch_size parameter sets per iteration
            batch = batches[0:batch_size]
            compiled_batch = compiled_batches[0:batch_size]

            batches[0:batch_size] = []
            compiled_batches[0:batch_size] = []

            if batches:
                current_batch_size = batch_size
            else:
                current_batch_size = len(batch)

            if generic_setinputsizes:
                # if setinputsizes is present, expand this collection to
                # suit the batch length as well
                # currently this will be mssql+pyodbc for internal dialects
                processed_setinputsizes = [
                    (new_key, len_, typ)
                    for new_key, len_, typ in (
                        (f"{key}_{index}", len_, typ)
                        for index in range(current_batch_size)
                        for key, len_, typ in generic_setinputsizes
                    )
                ]

            replaced_parameters: Any
            if self.positional:
                num_ins_params = imv.num_positional_params_counted

                batch_iterator: Iterable[Sequence[Any]]
                extra_params_left: Sequence[Any]
                extra_params_right: Sequence[Any]

                if num_ins_params == len(batch[0]):
                    extra_params_left = extra_params_right = ()
                    batch_iterator = batch
                else:
                    extra_params_left = batch[0][:expand_pos_lower_index]
                    extra_params_right = batch[0][expand_pos_upper_index:]
                    batch_iterator = (
                        b[expand_pos_lower_index:expand_pos_upper_index]
                        for b in batch
                    )

                if imv.embed_values_counter:
                    expanded_values_string = (
                        "".join(
                            executemany_values_w_comma.replace(
                                "_IMV_VALUES_COUNTER", str(i)
                            )
                            for i, _ in enumerate(batch)
                        )
                    )[:-2]
                else:
                    expanded_values_string = (
                        (executemany_values_w_comma * current_batch_size)
                    )[:-2]

                if self._numeric_binds and num_ins_params > 0:
                    # numeric will always number the parameters inside of
                    # VALUES (and thus order self.positiontup) to be higher
                    # than non-VALUES parameters, no matter where in the
                    # statement those non-VALUES parameters appear (this is
                    # ensured in _process_numeric by numbering first all
                    # params that are not in _values_bindparam)
                    # therefore all extra params are always
                    # on the left side and numbered lower than the VALUES
                    # parameters
                    assert not extra_params_right

                    start = expand_pos_lower_index + 1
                    end = num_ins_params * (current_batch_size) + start

                    # need to format here, since statement may contain
                    # unescaped %, while values_string contains just (%s, %s)
                    positions = tuple(
                        f"{self._numeric_binds_identifier_char}{i}"
                        for i in range(start, end)
                    )
                    expanded_values_string = expanded_values_string % positions

                replaced_statement = statement.replace(
                    "__EXECMANY_TOKEN__", expanded_values_string
                )

                replaced_parameters = tuple(
                    itertools.chain.from_iterable(batch_iterator)
                )

                replaced_parameters = (
                    extra_params_left
                    + replaced_parameters
                    + extra_params_right
                )

            else:
                # named parameters: render one VALUES group per row with
                # index-suffixed bind names, merging indexed parameters
                # on top of the pass-through base parameters
                replaced_values_clauses = []
                replaced_parameters = base_parameters.copy()

                for i, param in enumerate(batch):
                    fmv = formatted_values_clause.replace(
                        "EXECMANY_INDEX__", str(i)
                    )
                    if imv.embed_values_counter:
                        fmv = fmv.replace("_IMV_VALUES_COUNTER", str(i))

                    replaced_values_clauses.append(fmv)
                    replaced_parameters.update(
                        {f"{key}__{i}": param[key] for key in keys_to_replace}
                    )

                replaced_statement = statement.replace(
                    "__EXECMANY_TOKEN__",
                    ", ".join(replaced_values_clauses),
                )

            yield _InsertManyValuesBatch(
                replaced_statement,
                replaced_parameters,
                processed_setinputsizes,
                batch,
                (
                    [_sentinel_from_params(cb) for cb in compiled_batch]
                    if _sentinel_from_params
                    else []
                ),
                current_batch_size,
                batchnum,
                total_batches,
                sort_by_parameter_order,
                False,
            )
            batchnum += 1
5857
5858 def visit_insert(
5859 self, insert_stmt, visited_bindparam=None, visiting_cte=None, **kw
5860 ):
5861 compile_state = insert_stmt._compile_state_factory(
5862 insert_stmt, self, **kw
5863 )
5864 insert_stmt = compile_state.statement
5865
5866 if visiting_cte is not None:
5867 kw["visiting_cte"] = visiting_cte
5868 toplevel = False
5869 else:
5870 toplevel = not self.stack
5871
5872 if toplevel:
5873 self.isinsert = True
5874 if not self.dml_compile_state:
5875 self.dml_compile_state = compile_state
5876 if not self.compile_state:
5877 self.compile_state = compile_state
5878
5879 self.stack.append(
5880 {
5881 "correlate_froms": set(),
5882 "asfrom_froms": set(),
5883 "selectable": insert_stmt,
5884 }
5885 )
5886
5887 counted_bindparam = 0
5888
5889 # reset any incoming "visited_bindparam" collection
5890 visited_bindparam = None
5891
5892 # for positional, insertmanyvalues needs to know how many
5893 # bound parameters are in the VALUES sequence; there's no simple
5894 # rule because default expressions etc. can have zero or more
5895 # params inside them. After multiple attempts to figure this out,
5896 # this very simplistic "count after" works and is
5897 # likely the least amount of callcounts, though looks clumsy
5898 if self.positional and visiting_cte is None:
5899 # if we are inside a CTE, don't count parameters
5900 # here since they wont be for insertmanyvalues. keep
5901 # visited_bindparam at None so no counting happens.
5902 # see #9173
5903 visited_bindparam = []
5904
5905 crud_params_struct = crud._get_crud_params(
5906 self,
5907 insert_stmt,
5908 compile_state,
5909 toplevel,
5910 visited_bindparam=visited_bindparam,
5911 **kw,
5912 )
5913
5914 if self.positional and visited_bindparam is not None:
5915 counted_bindparam = len(visited_bindparam)
5916 if self._numeric_binds:
5917 if self._values_bindparam is not None:
5918 self._values_bindparam += visited_bindparam
5919 else:
5920 self._values_bindparam = visited_bindparam
5921
5922 crud_params_single = crud_params_struct.single_params
5923
5924 if (
5925 not crud_params_single
5926 and not self.dialect.supports_default_values
5927 and not self.dialect.supports_default_metavalue
5928 and not self.dialect.supports_empty_insert
5929 ):
5930 raise exc.CompileError(
5931 "The '%s' dialect with current database "
5932 "version settings does not support empty "
5933 "inserts." % self.dialect.name
5934 )
5935
5936 if compile_state._has_multi_parameters:
5937 if not self.dialect.supports_multivalues_insert:
5938 raise exc.CompileError(
5939 "The '%s' dialect with current database "
5940 "version settings does not support "
5941 "in-place multirow inserts." % self.dialect.name
5942 )
5943 elif (
5944 self.implicit_returning or insert_stmt._returning
5945 ) and insert_stmt._sort_by_parameter_order:
5946 raise exc.CompileError(
5947 "RETURNING cannot be determinstically sorted when "
5948 "using an INSERT which includes multi-row values()."
5949 )
5950 crud_params_single = crud_params_struct.single_params
5951 else:
5952 crud_params_single = crud_params_struct.single_params
5953
5954 preparer = self.preparer
5955 supports_default_values = self.dialect.supports_default_values
5956
5957 text = "INSERT "
5958
5959 if insert_stmt._prefixes:
5960 text += self._generate_prefixes(
5961 insert_stmt, insert_stmt._prefixes, **kw
5962 )
5963
5964 text += "INTO "
5965 table_text = preparer.format_table(insert_stmt.table)
5966
5967 if insert_stmt._hints:
5968 _, table_text = self._setup_crud_hints(insert_stmt, table_text)
5969
5970 if insert_stmt._independent_ctes:
5971 self._dispatch_independent_ctes(insert_stmt, kw)
5972
5973 text += table_text
5974
5975 if crud_params_single or not supports_default_values:
5976 text += " (%s)" % ", ".join(
5977 [expr for _, expr, _, _ in crud_params_single]
5978 )
5979
5980 # look for insertmanyvalues attributes that would have been configured
5981 # by crud.py as it scanned through the columns to be part of the
5982 # INSERT
5983 use_insertmanyvalues = crud_params_struct.use_insertmanyvalues
5984 named_sentinel_params: Optional[Sequence[str]] = None
5985 add_sentinel_cols = None
5986 implicit_sentinel = False
5987
5988 returning_cols = self.implicit_returning or insert_stmt._returning
5989 if returning_cols:
5990 add_sentinel_cols = crud_params_struct.use_sentinel_columns
5991 if add_sentinel_cols is not None:
5992 assert use_insertmanyvalues
5993
5994 # search for the sentinel column explicitly present
5995 # in the INSERT columns list, and additionally check that
5996 # this column has a bound parameter name set up that's in the
5997 # parameter list. If both of these cases are present, it means
5998 # we will have a client side value for the sentinel in each
5999 # parameter set.
6000
6001 _params_by_col = {
6002 col: param_names
6003 for col, _, _, param_names in crud_params_single
6004 }
6005 named_sentinel_params = []
6006 for _add_sentinel_col in add_sentinel_cols:
6007 if _add_sentinel_col not in _params_by_col:
6008 named_sentinel_params = None
6009 break
6010 param_name = self._within_exec_param_key_getter(
6011 _add_sentinel_col
6012 )
6013 if param_name not in _params_by_col[_add_sentinel_col]:
6014 named_sentinel_params = None
6015 break
6016 named_sentinel_params.append(param_name)
6017
6018 if named_sentinel_params is None:
6019 # if we are not going to have a client side value for
6020 # the sentinel in the parameter set, that means it's
6021 # an autoincrement, an IDENTITY, or a server-side SQL
6022 # expression like nextval('seqname'). So this is
6023 # an "implicit" sentinel; we will look for it in
6024 # RETURNING
6025 # only, and then sort on it. For this case on PG,
6026 # SQL Server we have to use a special INSERT form
6027 # that guarantees the server side function lines up with
6028 # the entries in the VALUES.
6029 if (
6030 self.dialect.insertmanyvalues_implicit_sentinel
6031 & InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT
6032 ):
6033 implicit_sentinel = True
6034 else:
6035 # here, we are not using a sentinel at all
6036 # and we are likely the SQLite dialect.
6037 # The first add_sentinel_col that we have should not
6038 # be marked as "insert_sentinel=True". if it was,
6039 # an error should have been raised in
6040 # _get_sentinel_column_for_table.
6041 assert not add_sentinel_cols[0]._insert_sentinel, (
6042 "sentinel selection rules should have prevented "
6043 "us from getting here for this dialect"
6044 )
6045
6046 # always put the sentinel columns last. even if they are
6047 # in the returning list already, they will be there twice
6048 # then.
6049 returning_cols = list(returning_cols) + list(add_sentinel_cols)
6050
6051 returning_clause = self.returning_clause(
6052 insert_stmt,
6053 returning_cols,
6054 populate_result_map=toplevel,
6055 )
6056
6057 if self.returning_precedes_values:
6058 text += " " + returning_clause
6059
6060 else:
6061 returning_clause = None
6062
6063 if insert_stmt.select is not None:
6064 # placed here by crud.py
6065 select_text = self.process(
6066 self.stack[-1]["insert_from_select"], insert_into=True, **kw
6067 )
6068
6069 if self.ctes and self.dialect.cte_follows_insert:
6070 nesting_level = len(self.stack) if not toplevel else None
6071 text += " %s%s" % (
6072 self._render_cte_clause(
6073 nesting_level=nesting_level,
6074 include_following_stack=True,
6075 ),
6076 select_text,
6077 )
6078 else:
6079 text += " %s" % select_text
6080 elif not crud_params_single and supports_default_values:
6081 text += " DEFAULT VALUES"
6082 if use_insertmanyvalues:
6083 self._insertmanyvalues = _InsertManyValues(
6084 True,
6085 self.dialect.default_metavalue_token,
6086 cast(
6087 "List[crud._CrudParamElementStr]", crud_params_single
6088 ),
6089 counted_bindparam,
6090 sort_by_parameter_order=(
6091 insert_stmt._sort_by_parameter_order
6092 ),
6093 includes_upsert_behaviors=(
6094 insert_stmt._post_values_clause is not None
6095 ),
6096 sentinel_columns=add_sentinel_cols,
6097 num_sentinel_columns=(
6098 len(add_sentinel_cols) if add_sentinel_cols else 0
6099 ),
6100 implicit_sentinel=implicit_sentinel,
6101 )
6102 elif compile_state._has_multi_parameters:
6103 text += " VALUES %s" % (
6104 ", ".join(
6105 "(%s)"
6106 % (", ".join(value for _, _, value, _ in crud_param_set))
6107 for crud_param_set in crud_params_struct.all_multi_params
6108 ),
6109 )
6110 else:
6111 insert_single_values_expr = ", ".join(
6112 [
6113 value
6114 for _, _, value, _ in cast(
6115 "List[crud._CrudParamElementStr]",
6116 crud_params_single,
6117 )
6118 ]
6119 )
6120
6121 if use_insertmanyvalues:
6122 if (
6123 implicit_sentinel
6124 and (
6125 self.dialect.insertmanyvalues_implicit_sentinel
6126 & InsertmanyvaluesSentinelOpts.USE_INSERT_FROM_SELECT
6127 )
6128 # this is checking if we have
6129 # INSERT INTO table (id) VALUES (DEFAULT).
6130 and not (crud_params_struct.is_default_metavalue_only)
6131 ):
6132 # if we have a sentinel column that is server generated,
6133 # then for selected backends render the VALUES list as a
6134 # subquery. This is the orderable form supported by
6135 # PostgreSQL and SQL Server.
6136 embed_sentinel_value = True
6137
6138 render_bind_casts = (
6139 self.dialect.insertmanyvalues_implicit_sentinel
6140 & InsertmanyvaluesSentinelOpts.RENDER_SELECT_COL_CASTS
6141 )
6142
6143 colnames = ", ".join(
6144 f"p{i}" for i, _ in enumerate(crud_params_single)
6145 )
6146
6147 if render_bind_casts:
6148 # render casts for the SELECT list. For PG, we are
6149 # already rendering bind casts in the parameter list,
6150 # selectively for the more "tricky" types like ARRAY.
6151 # however, even for the "easy" types, if the parameter
6152 # is NULL for every entry, PG gives up and says
6153 # "it must be TEXT", which fails for other easy types
6154 # like ints. So we cast on this side too.
6155 colnames_w_cast = ", ".join(
6156 self.render_bind_cast(
6157 col.type,
6158 col.type._unwrapped_dialect_impl(self.dialect),
6159 f"p{i}",
6160 )
6161 for i, (col, *_) in enumerate(crud_params_single)
6162 )
6163 else:
6164 colnames_w_cast = colnames
6165
6166 text += (
6167 f" SELECT {colnames_w_cast} FROM "
6168 f"(VALUES ({insert_single_values_expr})) "
6169 f"AS imp_sen({colnames}, sen_counter) "
6170 "ORDER BY sen_counter"
6171 )
6172 else:
6173 # otherwise, if no sentinel or backend doesn't support
6174 # orderable subquery form, use a plain VALUES list
6175 embed_sentinel_value = False
6176 text += f" VALUES ({insert_single_values_expr})"
6177
6178 self._insertmanyvalues = _InsertManyValues(
6179 is_default_expr=False,
6180 single_values_expr=insert_single_values_expr,
6181 insert_crud_params=cast(
6182 "List[crud._CrudParamElementStr]",
6183 crud_params_single,
6184 ),
6185 num_positional_params_counted=counted_bindparam,
6186 sort_by_parameter_order=(
6187 insert_stmt._sort_by_parameter_order
6188 ),
6189 includes_upsert_behaviors=(
6190 insert_stmt._post_values_clause is not None
6191 ),
6192 sentinel_columns=add_sentinel_cols,
6193 num_sentinel_columns=(
6194 len(add_sentinel_cols) if add_sentinel_cols else 0
6195 ),
6196 sentinel_param_keys=named_sentinel_params,
6197 implicit_sentinel=implicit_sentinel,
6198 embed_values_counter=embed_sentinel_value,
6199 )
6200
6201 else:
6202 text += f" VALUES ({insert_single_values_expr})"
6203
6204 if insert_stmt._post_values_clause is not None:
6205 post_values_clause = self.process(
6206 insert_stmt._post_values_clause, **kw
6207 )
6208 if post_values_clause:
6209 text += " " + post_values_clause
6210
6211 if returning_clause and not self.returning_precedes_values:
6212 text += " " + returning_clause
6213
6214 if self.ctes and not self.dialect.cte_follows_insert:
6215 nesting_level = len(self.stack) if not toplevel else None
6216 text = (
6217 self._render_cte_clause(
6218 nesting_level=nesting_level,
6219 include_following_stack=True,
6220 )
6221 + text
6222 )
6223
6224 self.stack.pop(-1)
6225
6226 return text
6227
6228 def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
6229 """Provide a hook to override the initial table clause
6230 in an UPDATE statement.
6231
6232 MySQL overrides this.
6233
6234 """
6235 kw["asfrom"] = True
6236 return from_table._compiler_dispatch(self, iscrud=True, **kw)
6237
6238 def update_from_clause(
6239 self, update_stmt, from_table, extra_froms, from_hints, **kw
6240 ):
6241 """Provide a hook to override the generation of an
6242 UPDATE..FROM clause.
6243 MySQL and MSSQL override this.
6244 """
6245 raise NotImplementedError(
6246 "This backend does not support multiple-table "
6247 "criteria within UPDATE"
6248 )
6249
6250 def update_post_criteria_clause(
6251 self, update_stmt: Update, **kw: Any
6252 ) -> Optional[str]:
6253 """provide a hook to override generation after the WHERE criteria
6254 in an UPDATE statement
6255
6256 .. versionadded:: 2.1
6257
6258 """
6259 if update_stmt._post_criteria_clause is not None:
6260 return self.process(
6261 update_stmt._post_criteria_clause,
6262 **kw,
6263 )
6264 else:
6265 return None
6266
6267 def delete_post_criteria_clause(
6268 self, delete_stmt: Delete, **kw: Any
6269 ) -> Optional[str]:
6270 """provide a hook to override generation after the WHERE criteria
6271 in a DELETE statement
6272
6273 .. versionadded:: 2.1
6274
6275 """
6276 if delete_stmt._post_criteria_clause is not None:
6277 return self.process(
6278 delete_stmt._post_criteria_clause,
6279 **kw,
6280 )
6281 else:
6282 return None
6283
    def visit_update(
        self,
        update_stmt: Update,
        visiting_cte: Optional[CTE] = None,
        **kw: Any,
    ) -> str:
        """Render an UPDATE statement.

        Clauses are assembled strictly in order: prefixes, table clause,
        SET, RETURNING (placement is dialect-dependent), extra FROM
        tables, WHERE, post-criteria clause, then any WITH/CTE prefix.
        """
        # the compile state may substitute a processed variant of the
        # statement
        compile_state = update_stmt._compile_state_factory(
            update_stmt, self, **kw
        )
        if TYPE_CHECKING:
            assert isinstance(compile_state, UpdateDMLState)
        update_stmt = compile_state.statement  # type: ignore[assignment]

        # "toplevel" is True only when this is the outermost statement
        # being compiled (empty compiler stack, not inside a CTE)
        if visiting_cte is not None:
            kw["visiting_cte"] = visiting_cte
            toplevel = False
        else:
            toplevel = not self.stack

        if toplevel:
            self.isupdate = True
            if not self.dml_compile_state:
                self.dml_compile_state = compile_state
            if not self.compile_state:
                self.compile_state = compile_state

        # cartesian-product linting collects FROM elements; warnings are
        # emitted only for the outermost statement
        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
            from_linter = FromLinter({}, set())
            warn_linting = self.linting & WARN_LINTING
            if toplevel:
                self.from_linter = from_linter
        else:
            from_linter = None
            warn_linting = False

        extra_froms = compile_state._extra_froms
        is_multitable = bool(extra_froms)

        if is_multitable:
            # main table might be a JOIN
            main_froms = set(_from_objects(update_stmt.table))
            render_extra_froms = [
                f for f in extra_froms if f not in main_froms
            ]
            correlate_froms = main_froms.union(extra_froms)
        else:
            render_extra_froms = []
            correlate_froms = {update_stmt.table}

        # push entry so nested elements know what to correlate against
        self.stack.append(
            {
                "correlate_froms": correlate_froms,
                "asfrom_froms": correlate_froms,
                "selectable": update_stmt,
            }
        )

        text = "UPDATE "

        if update_stmt._prefixes:
            text += self._generate_prefixes(
                update_stmt, update_stmt._prefixes, **kw
            )

        table_text = self.update_tables_clause(
            update_stmt,
            update_stmt.table,
            render_extra_froms,
            from_linter=from_linter,
            **kw,
        )
        # collect the SET column/value pairs; must run before hints and
        # CTE dispatch so bind parameters are established first
        crud_params_struct = crud._get_crud_params(
            self, update_stmt, compile_state, toplevel, **kw
        )
        crud_params = crud_params_struct.single_params

        if update_stmt._hints:
            dialect_hints, table_text = self._setup_crud_hints(
                update_stmt, table_text
            )
        else:
            dialect_hints = None

        if update_stmt._independent_ctes:
            self._dispatch_independent_ctes(update_stmt, kw)

        text += table_text

        text += " SET "
        text += ", ".join(
            expr + "=" + value
            for _, expr, value, _ in cast(
                "List[Tuple[Any, str, str, Any]]", crud_params
            )
        )

        # some dialects render RETURNING before the FROM/WHERE portion
        if self.implicit_returning or update_stmt._returning:
            if self.returning_precedes_values:
                text += " " + self.returning_clause(
                    update_stmt,
                    self.implicit_returning or update_stmt._returning,
                    populate_result_map=toplevel,
                )

        if extra_froms:
            extra_from_text = self.update_from_clause(
                update_stmt,
                update_stmt.table,
                render_extra_froms,
                dialect_hints,
                from_linter=from_linter,
                **kw,
            )
            if extra_from_text:
                text += " " + extra_from_text

        if update_stmt._where_criteria:
            t = self._generate_delimited_and_list(
                update_stmt._where_criteria, from_linter=from_linter, **kw
            )
            if t:
                text += " WHERE " + t

        ulc = self.update_post_criteria_clause(
            update_stmt, from_linter=from_linter, **kw
        )
        if ulc:
            text += " " + ulc

        if (
            self.implicit_returning or update_stmt._returning
        ) and not self.returning_precedes_values:
            text += " " + self.returning_clause(
                update_stmt,
                self.implicit_returning or update_stmt._returning,
                populate_result_map=toplevel,
            )

        # any accumulated CTEs render as a WITH clause ahead of the
        # whole statement
        if self.ctes:
            nesting_level = len(self.stack) if not toplevel else None
            text = self._render_cte_clause(nesting_level=nesting_level) + text

        if warn_linting:
            assert from_linter is not None
            from_linter.warn(stmt_type="UPDATE")

        self.stack.pop(-1)

        return text  # type: ignore[no-any-return]
6433
6434 def delete_extra_from_clause(
6435 self, delete_stmt, from_table, extra_froms, from_hints, **kw
6436 ):
6437 """Provide a hook to override the generation of an
6438 DELETE..FROM clause.
6439
6440 This can be used to implement DELETE..USING for example.
6441
6442 MySQL and MSSQL override this.
6443
6444 """
6445 raise NotImplementedError(
6446 "This backend does not support multiple-table "
6447 "criteria within DELETE"
6448 )
6449
6450 def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
6451 return from_table._compiler_dispatch(
6452 self, asfrom=True, iscrud=True, **kw
6453 )
6454
    def visit_delete(self, delete_stmt, visiting_cte=None, **kw):
        """Render a DELETE statement.

        Clauses are assembled strictly in order: prefixes, FROM table,
        RETURNING (placement is dialect-dependent), extra FROM tables,
        WHERE, post-criteria clause, then any WITH/CTE prefix.
        """
        # the compile state may substitute a processed variant of the
        # statement
        compile_state = delete_stmt._compile_state_factory(
            delete_stmt, self, **kw
        )
        delete_stmt = compile_state.statement

        # "toplevel" is True only when this is the outermost statement
        # being compiled (empty compiler stack, not inside a CTE)
        if visiting_cte is not None:
            kw["visiting_cte"] = visiting_cte
            toplevel = False
        else:
            toplevel = not self.stack

        if toplevel:
            self.isdelete = True
            if not self.dml_compile_state:
                self.dml_compile_state = compile_state
            if not self.compile_state:
                self.compile_state = compile_state

        # cartesian-product linting collects FROM elements; warnings are
        # emitted only for the outermost statement
        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
            from_linter = FromLinter({}, set())
            warn_linting = self.linting & WARN_LINTING
            if toplevel:
                self.from_linter = from_linter
        else:
            from_linter = None
            warn_linting = False

        extra_froms = compile_state._extra_froms

        # push entry so nested elements know what to correlate against
        correlate_froms = {delete_stmt.table}.union(extra_froms)
        self.stack.append(
            {
                "correlate_froms": correlate_froms,
                "asfrom_froms": correlate_froms,
                "selectable": delete_stmt,
            }
        )

        text = "DELETE "

        if delete_stmt._prefixes:
            text += self._generate_prefixes(
                delete_stmt, delete_stmt._prefixes, **kw
            )

        text += "FROM "

        try:
            table_text = self.delete_table_clause(
                delete_stmt,
                delete_stmt.table,
                extra_froms,
                from_linter=from_linter,
            )
        except TypeError:
            # anticipate 3rd party dialects that don't include **kw
            # TODO: remove in 2.1
            table_text = self.delete_table_clause(
                delete_stmt, delete_stmt.table, extra_froms
            )
            if from_linter:
                # feed the table to the linter manually, since the
                # legacy override above couldn't accept from_linter
                _ = self.process(delete_stmt.table, from_linter=from_linter)

        # run CRUD parameter processing for its setup side effects;
        # DELETE has no SET clause so the return value is not used
        crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw)

        if delete_stmt._hints:
            dialect_hints, table_text = self._setup_crud_hints(
                delete_stmt, table_text
            )
        else:
            dialect_hints = None

        if delete_stmt._independent_ctes:
            self._dispatch_independent_ctes(delete_stmt, kw)

        text += table_text

        # some dialects render RETURNING before the WHERE portion
        if (
            self.implicit_returning or delete_stmt._returning
        ) and self.returning_precedes_values:
            text += " " + self.returning_clause(
                delete_stmt,
                self.implicit_returning or delete_stmt._returning,
                populate_result_map=toplevel,
            )

        if extra_froms:
            extra_from_text = self.delete_extra_from_clause(
                delete_stmt,
                delete_stmt.table,
                extra_froms,
                dialect_hints,
                from_linter=from_linter,
                **kw,
            )
            if extra_from_text:
                text += " " + extra_from_text

        if delete_stmt._where_criteria:
            t = self._generate_delimited_and_list(
                delete_stmt._where_criteria, from_linter=from_linter, **kw
            )
            if t:
                text += " WHERE " + t

        dlc = self.delete_post_criteria_clause(
            delete_stmt, from_linter=from_linter, **kw
        )
        if dlc:
            text += " " + dlc

        if (
            self.implicit_returning or delete_stmt._returning
        ) and not self.returning_precedes_values:
            text += " " + self.returning_clause(
                delete_stmt,
                self.implicit_returning or delete_stmt._returning,
                populate_result_map=toplevel,
            )

        # any accumulated CTEs render as a WITH clause ahead of the
        # whole statement
        if self.ctes:
            nesting_level = len(self.stack) if not toplevel else None
            text = self._render_cte_clause(nesting_level=nesting_level) + text

        if warn_linting:
            assert from_linter is not None
            from_linter.warn(stmt_type="DELETE")

        self.stack.pop(-1)

        return text
6587
6588 def visit_savepoint(self, savepoint_stmt, **kw):
6589 return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
6590
6591 def visit_rollback_to_savepoint(self, savepoint_stmt, **kw):
6592 return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(
6593 savepoint_stmt
6594 )
6595
6596 def visit_release_savepoint(self, savepoint_stmt, **kw):
6597 return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(
6598 savepoint_stmt
6599 )
6600
6601
class StrSQLCompiler(SQLCompiler):
    """A :class:`.SQLCompiler` subclass which allows a small selection
    of non-standard SQL features to render into a string value.

    The :class:`.StrSQLCompiler` is invoked whenever a Core expression
    element is directly stringified without calling upon the
    :meth:`_expression.ClauseElement.compile` method.
    It can render a limited set
    of non-standard SQL constructs to assist in basic stringification,
    however for more substantial custom or dialect-specific SQL constructs,
    it will be necessary to make use of
    :meth:`_expression.ClauseElement.compile`
    directly.

    .. seealso::

        :ref:`faq_sql_expression_string`

    """

    def _fallback_column_name(self, column):
        # placeholder rendered when a column's name cannot be determined
        return "<name unknown>"

    @util.preload_module("sqlalchemy.engine.url")
    def visit_unsupported_compilation(self, element, err, **kw):
        # when the element declares its own stringify dialect, delegate
        # to that dialect's real compiler before falling back to the
        # generic "unsupported compilation" handling
        if element.stringify_dialect != "default":
            url_module = util.preloaded.engine_url
            real_dialect = url_module.URL.create(
                element.stringify_dialect
            ).get_dialect()()

            delegate = real_dialect.statement_compiler(
                real_dialect, None, _supporting_against=self
            )
            if not isinstance(delegate, StrSQLCompiler):
                return delegate.process(element, **kw)

        return super().visit_unsupported_compilation(element, err)

    def visit_getitem_binary(self, binary, operator, **kw):
        # render indexed access using bracket syntax
        lhs = self.process(binary.left, **kw)
        rhs = self.process(binary.right, **kw)
        return f"{lhs}[{rhs}]"

    def visit_json_getitem_op_binary(self, binary, operator, **kw):
        return self.visit_getitem_binary(binary, operator, **kw)

    def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
        return self.visit_getitem_binary(binary, operator, **kw)

    def visit_sequence(self, sequence, **kw):
        return "<next sequence value: %s>" % (
            self.preparer.format_sequence(sequence),
        )

    def returning_clause(
        self,
        stmt: UpdateBase,
        returning_cols: Sequence[_ColumnsClauseElement],
        *,
        populate_result_map: bool,
        **kw: Any,
    ) -> str:
        rendered = [
            self._label_select_column(None, col, True, False, {})
            for col in base._select_iterables(returning_cols)
        ]
        return "RETURNING " + ", ".join(rendered)

    def update_from_clause(
        self, update_stmt, from_table, extra_froms, from_hints, **kw
    ):
        kw["asfrom"] = True
        rendered = [
            extra._compiler_dispatch(self, fromhints=from_hints, **kw)
            for extra in extra_froms
        ]
        return "FROM " + ", ".join(rendered)

    def delete_extra_from_clause(
        self, delete_stmt, from_table, extra_froms, from_hints, **kw
    ):
        kw["asfrom"] = True
        rendered = [
            extra._compiler_dispatch(self, fromhints=from_hints, **kw)
            for extra in extra_froms
        ]
        return ", " + ", ".join(rendered)

    def visit_empty_set_expr(self, element_types, **kw):
        # a SELECT guaranteed to return no rows
        return "SELECT 1 WHERE 1!=1"

    def get_from_hint_text(self, table, text):
        return f"[{text}]"

    def visit_regexp_match_op_binary(self, binary, operator, **kw):
        return self._generate_generic_binary(binary, " <regexp> ", **kw)

    def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
        return self._generate_generic_binary(binary, " <not regexp> ", **kw)

    def visit_regexp_replace_op_binary(self, binary, operator, **kw):
        lhs = binary.left._compiler_dispatch(self, **kw)
        rhs = binary.right._compiler_dispatch(self, **kw)
        return f"<regexp replace>({lhs}, {rhs})"

    def visit_try_cast(self, cast, **kwargs):
        value_text = cast.clause._compiler_dispatch(self, **kwargs)
        type_text = cast.typeclause._compiler_dispatch(self, **kwargs)
        return f"TRY_CAST({value_text} AS {type_text})"
6711
6712
class DDLCompiler(Compiled):
    """Default compiler for DDL constructs (CREATE / DROP / ALTER).

    ``visit_`` methods render :class:`.ExecutableDDLElement` objects
    into strings; embedded SQL expressions (column defaults, CHECK
    constraint text, index expressions) are delegated to
    :attr:`.sql_compiler`.
    """

    is_ddl = True

    if TYPE_CHECKING:

        def __init__(
            self,
            dialect: Dialect,
            statement: ExecutableDDLElement,
            schema_translate_map: Optional[SchemaTranslateMapType] = ...,
            render_schema_translate: bool = ...,
            compile_kwargs: Mapping[str, Any] = ...,
        ): ...

    @util.ro_memoized_property
    def sql_compiler(self) -> SQLCompiler:
        """A statement compiler from the same dialect, used to render
        SQL expressions embedded within DDL."""
        return self.dialect.statement_compiler(
            self.dialect, None, schema_translate_map=self.schema_translate_map
        )

    @util.memoized_property
    def type_compiler(self):
        return self.dialect.type_compiler_instance

    def construct_params(
        self,
        params: Optional[_CoreSingleExecuteParams] = None,
        extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
        escape_names: bool = True,
    ) -> Optional[_MutableCoreSingleExecuteParams]:
        # DDL statements carry no bound parameters
        return None

    def visit_ddl(self, ddl, **kwargs):
        """Render a literal :class:`.DDL` string, interpolating
        ``table``/``schema``/``fullname`` tokens for Table targets."""
        # table events can substitute table and schema name
        context = ddl.context
        if isinstance(ddl.target, schema.Table):
            context = context.copy()

            preparer = self.preparer
            path = preparer.format_table_seq(ddl.target)
            if len(path) == 1:
                table, sch = path[0], ""
            else:
                table, sch = path[-1], path[0]

            context.setdefault("table", table)
            context.setdefault("schema", sch)
            context.setdefault("fullname", preparer.format_table(ddl.target))

        return self.sql_compiler.post_process_text(ddl.statement % context)

    def visit_create_schema(self, create, **kw):
        """Render CREATE SCHEMA [IF NOT EXISTS] <name>."""
        text = "CREATE SCHEMA "
        if create.if_not_exists:
            text += "IF NOT EXISTS "
        return text + self.preparer.format_schema(create.element)

    def visit_drop_schema(self, drop, **kw):
        """Render DROP SCHEMA [IF EXISTS] <name> [CASCADE]."""
        text = "DROP SCHEMA "
        if drop.if_exists:
            text += "IF EXISTS "
        text += self.preparer.format_schema(drop.element)
        if drop.cascade:
            text += " CASCADE"
        return text

    def visit_create_table(self, create, **kw):
        """Render a complete CREATE TABLE statement, including columns,
        table-level constraints and dialect-specific suffixes."""
        table = create.element
        preparer = self.preparer

        text = "\nCREATE "
        if table._prefixes:
            text += " ".join(table._prefixes) + " "

        text += "TABLE "
        if create.if_not_exists:
            text += "IF NOT EXISTS "

        text += preparer.format_table(table) + " "

        create_table_suffix = self.create_table_suffix(table)
        if create_table_suffix:
            text += create_table_suffix + " "

        text += "("

        separator = "\n"

        # if only one primary key, specify it along with the column
        first_pk = False
        for create_column in create.columns:
            column = create_column.element
            try:
                processed = self.process(
                    create_column, first_pk=column.primary_key and not first_pk
                )
                # a None result means the column is omitted (e.g.
                # "system" columns)
                if processed is not None:
                    text += separator
                    separator = ", \n"
                    text += "\t" + processed
                if column.primary_key:
                    first_pk = True
            except exc.CompileError as ce:
                # re-raise with table/column context for diagnostics
                raise exc.CompileError(
                    "(in table '%s', column '%s'): %s"
                    % (table.description, column.name, ce.args[0])
                ) from ce

        const = self.create_table_constraints(
            table,
            _include_foreign_key_constraints=create.include_foreign_key_constraints,  # noqa
        )
        if const:
            text += separator + "\t" + const

        text += "\n)%s\n\n" % self.post_create_table(table)
        return text

    def visit_create_column(self, create, first_pk=False, **kw):
        """Render a single column specification within CREATE TABLE.

        Returns None for "system" columns, which are omitted entirely.
        """
        column = create.element

        if column.system:
            return None

        text = self.get_column_specification(column, first_pk=first_pk)
        const = " ".join(
            self.process(constraint) for constraint in column.constraints
        )
        if const:
            text += " " + const

        return text

    def create_table_constraints(
        self, table, _include_foreign_key_constraints=None, **kw
    ):
        """Render the table-level constraint clauses of CREATE TABLE,
        joined by comma/newline/tab separators."""
        # On some DB order is significant: visit PK first, then the
        # other constraints (engine.ReflectionTest.testbasic failed on FB2)
        constraints = []
        if table.primary_key:
            constraints.append(table.primary_key)

        all_fkcs = table.foreign_key_constraints
        if _include_foreign_key_constraints is not None:
            omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints)
        else:
            omit_fkcs = set()

        constraints.extend(
            [
                c
                for c in table._sorted_constraints
                if c is not table.primary_key and c not in omit_fkcs
            ]
        )

        # skip constraints the dialect doesn't want inline, and those
        # deferred to ALTER via use_alter
        return ", \n\t".join(
            p
            for p in (
                self.process(constraint)
                for constraint in constraints
                if (constraint._should_create_for_compiler(self))
                and (
                    not self.dialect.supports_alter
                    or not getattr(constraint, "use_alter", False)
                )
            )
            if p is not None
        )

    def visit_drop_table(self, drop, **kw):
        """Render DROP TABLE [IF EXISTS] <name>."""
        text = "\nDROP TABLE "
        if drop.if_exists:
            text += "IF EXISTS "
        return text + self.preparer.format_table(drop.element)

    def visit_drop_view(self, drop, **kw):
        """Render DROP VIEW <name>."""
        return "\nDROP VIEW " + self.preparer.format_table(drop.element)

    def _verify_index_table(self, index: Index) -> None:
        # an index must be associated with a table before DDL can render
        if index.table is None:
            raise exc.CompileError(
                "Index '%s' is not associated with any table." % index.name
            )

    def visit_create_index(
        self, create, include_schema=False, include_table_schema=True, **kw
    ):
        """Render CREATE [UNIQUE] INDEX [IF NOT EXISTS] <name>
        ON <table> (<expressions>)."""
        index = create.element
        self._verify_index_table(index)
        preparer = self.preparer
        text = "CREATE "
        if index.unique:
            text += "UNIQUE "
        if index.name is None:
            raise exc.CompileError(
                "CREATE INDEX requires that the index have a name"
            )

        text += "INDEX "
        if create.if_not_exists:
            text += "IF NOT EXISTS "

        text += "%s ON %s (%s)" % (
            self._prepared_index_name(index, include_schema=include_schema),
            preparer.format_table(
                index.table, use_schema=include_table_schema
            ),
            ", ".join(
                self.sql_compiler.process(
                    expr, include_table=False, literal_binds=True
                )
                for expr in index.expressions
            ),
        )
        return text

    def visit_drop_index(self, drop, **kw):
        """Render DROP INDEX [IF EXISTS] <name>."""
        index = drop.element

        if index.name is None:
            raise exc.CompileError(
                "DROP INDEX requires that the index have a name"
            )
        text = "\nDROP INDEX "
        if drop.if_exists:
            text += "IF EXISTS "

        return text + self._prepared_index_name(index, include_schema=True)

    def _prepared_index_name(
        self, index: Index, include_schema: bool = False
    ) -> str:
        """Return the formatted index name, optionally qualified with
        the table's effective schema."""
        if index.table is not None:
            effective_schema = self.preparer.schema_for_object(index.table)
        else:
            effective_schema = None
        if include_schema and effective_schema:
            schema_name = self.preparer.quote_schema(effective_schema)
        else:
            schema_name = None

        index_name: str = self.preparer.format_index(index)

        if schema_name:
            index_name = schema_name + "." + index_name
        return index_name

    def visit_add_constraint(self, create, **kw):
        """Render ALTER TABLE <table> ADD <constraint>."""
        return "ALTER TABLE %s ADD %s" % (
            self.preparer.format_table(create.element.table),
            self.process(create.element),
        )

    def visit_set_table_comment(self, create, **kw):
        """Render COMMENT ON TABLE ... IS <literal>."""
        return "COMMENT ON TABLE %s IS %s" % (
            self.preparer.format_table(create.element),
            self.sql_compiler.render_literal_value(
                create.element.comment, sqltypes.String()
            ),
        )

    def visit_drop_table_comment(self, drop, **kw):
        """Render COMMENT ON TABLE ... IS NULL."""
        return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
            drop.element
        )

    def visit_set_column_comment(self, create, **kw):
        """Render COMMENT ON COLUMN ... IS <literal>."""
        return "COMMENT ON COLUMN %s IS %s" % (
            self.preparer.format_column(
                create.element, use_table=True, use_schema=True
            ),
            self.sql_compiler.render_literal_value(
                create.element.comment, sqltypes.String()
            ),
        )

    def visit_drop_column_comment(self, drop, **kw):
        """Render COMMENT ON COLUMN ... IS NULL."""
        return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column(
            drop.element, use_table=True
        )

    def visit_set_constraint_comment(self, create, **kw):
        # no generic syntax; dialects that support it override
        raise exc.UnsupportedCompilationError(self, type(create))

    def visit_drop_constraint_comment(self, drop, **kw):
        # no generic syntax; dialects that support it override
        raise exc.UnsupportedCompilationError(self, type(drop))

    def get_identity_options(self, identity_options):
        """Render the option keywords shared by sequences and identity
        columns (INCREMENT BY, START WITH, MINVALUE, CACHE, CYCLE...)."""
        text = []
        if identity_options.increment is not None:
            text.append("INCREMENT BY %d" % identity_options.increment)
        if identity_options.start is not None:
            text.append("START WITH %d" % identity_options.start)
        if identity_options.minvalue is not None:
            text.append("MINVALUE %d" % identity_options.minvalue)
        if identity_options.maxvalue is not None:
            text.append("MAXVALUE %d" % identity_options.maxvalue)
        if identity_options.nominvalue is not None:
            text.append("NO MINVALUE")
        if identity_options.nomaxvalue is not None:
            text.append("NO MAXVALUE")
        if identity_options.cache is not None:
            text.append("CACHE %d" % identity_options.cache)
        if identity_options.cycle is not None:
            text.append("CYCLE" if identity_options.cycle else "NO CYCLE")
        return " ".join(text)

    def visit_create_sequence(self, create, prefix=None, **kw):
        """Render CREATE SEQUENCE [IF NOT EXISTS] <name> plus options."""
        text = "CREATE SEQUENCE "
        if create.if_not_exists:
            text += "IF NOT EXISTS "
        text += self.preparer.format_sequence(create.element)

        if prefix:
            text += prefix
        options = self.get_identity_options(create.element)
        if options:
            text += " " + options
        return text

    def visit_drop_sequence(self, drop, **kw):
        """Render DROP SEQUENCE [IF EXISTS] <name>."""
        text = "DROP SEQUENCE "
        if drop.if_exists:
            text += "IF EXISTS "
        return text + self.preparer.format_sequence(drop.element)

    def visit_drop_constraint(self, drop, **kw):
        """Render ALTER TABLE ... DROP CONSTRAINT; the constraint must
        have a (formattable) name."""
        constraint = drop.element
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
        else:
            formatted_name = None

        if formatted_name is None:
            raise exc.CompileError(
                "Can't emit DROP CONSTRAINT for constraint %r; "
                "it has no name" % drop.element
            )
        return "ALTER TABLE %s DROP CONSTRAINT %s%s%s" % (
            self.preparer.format_table(drop.element.table),
            "IF EXISTS " if drop.if_exists else "",
            formatted_name,
            " CASCADE" if drop.cascade else "",
        )

    def get_column_specification(self, column, **kwargs):
        """Render a column's name, type, DEFAULT, computed/identity
        clauses and NOT NULL qualifier for CREATE TABLE."""
        colspec = (
            self.preparer.format_column(column)
            + " "
            + self.dialect.type_compiler_instance.process(
                column.type, type_expression=column
            )
        )
        default = self.get_column_default_string(column)
        if default is not None:
            colspec += " DEFAULT " + default

        if column.computed is not None:
            colspec += " " + self.process(column.computed)

        if (
            column.identity is not None
            and self.dialect.supports_identity_columns
        ):
            colspec += " " + self.process(column.identity)

        # NOT NULL is omitted when an identity clause was rendered above
        if not column.nullable and (
            not column.identity or not self.dialect.supports_identity_columns
        ):
            colspec += " NOT NULL"
        return colspec

    def create_table_suffix(self, table):
        # hook: text rendered between the table name and the column list
        return ""

    def post_create_table(self, table):
        # hook: text rendered after the closing parenthesis
        return ""

    def get_column_default_string(self, column: Column[Any]) -> Optional[str]:
        """Return the rendered server default for a column, or None
        when no :class:`.DefaultClause` is present."""
        if isinstance(column.server_default, schema.DefaultClause):
            return self.render_default_string(column.server_default.arg)
        else:
            return None

    def render_default_string(self, default: Union[Visitable, str]) -> str:
        """Render a server default: plain strings as quoted SQL string
        literals, SQL expressions compiled with literal binds."""
        if isinstance(default, str):
            return self.sql_compiler.render_literal_value(
                default, sqltypes.STRINGTYPE
            )
        else:
            return self.sql_compiler.process(default, literal_binds=True)

    def visit_table_or_column_check_constraint(self, constraint, **kw):
        # dispatch based on whether the CHECK applies to a single column
        # or to the whole table
        if constraint.is_column_level:
            return self.visit_column_check_constraint(constraint)
        else:
            return self.visit_check_constraint(constraint)

    def visit_check_constraint(self, constraint, **kw):
        """Render a table-level CHECK constraint."""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "CHECK (%s)" % self.sql_compiler.process(
            constraint.sqltext, include_table=False, literal_binds=True
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_column_check_constraint(self, constraint, **kw):
        """Render a column-level CHECK constraint."""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "CHECK (%s)" % self.sql_compiler.process(
            constraint.sqltext, include_table=False, literal_binds=True
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_primary_key_constraint(
        self, constraint: PrimaryKeyConstraint, **kw: Any
    ) -> str:
        """Render a PRIMARY KEY table constraint; empty constraints
        render as an empty string."""
        if len(constraint) == 0:
            return ""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "PRIMARY KEY "
        text += "(%s)" % ", ".join(
            self.preparer.quote(c.name)
            for c in (
                constraint.columns_autoinc_first
                if constraint._implicit_generated
                else constraint.columns
            )
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_foreign_key_constraint(self, constraint, **kw):
        """Render a FOREIGN KEY ... REFERENCES clause including MATCH,
        cascade and deferrability options."""
        preparer = self.preparer
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        remote_table = list(constraint.elements)[0].column.table
        text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
            ", ".join(
                preparer.quote(f.parent.name) for f in constraint.elements
            ),
            self.define_constraint_remote_table(
                constraint, remote_table, preparer
            ),
            ", ".join(
                preparer.quote(f.column.name) for f in constraint.elements
            ),
        )
        text += self.define_constraint_match(constraint)
        text += self.define_constraint_cascades(constraint)
        text += self.define_constraint_deferrability(constraint)
        return text

    def define_constraint_remote_table(self, constraint, table, preparer):
        """Format the remote table clause of a CREATE CONSTRAINT clause."""

        return preparer.format_table(table)

    def visit_unique_constraint(
        self, constraint: UniqueConstraint, **kw: Any
    ) -> str:
        """Render a UNIQUE table constraint; empty constraints render
        as an empty string."""
        if len(constraint) == 0:
            return ""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "UNIQUE %s(%s)" % (
            self.define_unique_constraint_distinct(constraint, **kw),
            ", ".join(self.preparer.quote(c.name) for c in constraint),
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def define_unique_constraint_distinct(
        self, constraint: UniqueConstraint, **kw: Any
    ) -> str:
        # hook: text rendered between UNIQUE and the column list
        return ""

    def define_constraint_cascades(
        self, constraint: ForeignKeyConstraint
    ) -> str:
        """Render ON DELETE / ON UPDATE rules for a foreign key."""
        text = ""
        if constraint.ondelete is not None:
            text += self.define_constraint_ondelete_cascade(constraint)

        if constraint.onupdate is not None:
            text += self.define_constraint_onupdate_cascade(constraint)
        return text

    def define_constraint_ondelete_cascade(
        self, constraint: ForeignKeyConstraint
    ) -> str:
        return " ON DELETE %s" % self.preparer.validate_sql_phrase(
            constraint.ondelete, FK_ON_DELETE
        )

    def define_constraint_onupdate_cascade(
        self, constraint: ForeignKeyConstraint
    ) -> str:
        return " ON UPDATE %s" % self.preparer.validate_sql_phrase(
            constraint.onupdate, FK_ON_UPDATE
        )

    def define_constraint_deferrability(self, constraint: Constraint) -> str:
        """Render DEFERRABLE / NOT DEFERRABLE and INITIALLY clauses."""
        text = ""
        if constraint.deferrable is not None:
            if constraint.deferrable:
                text += " DEFERRABLE"
            else:
                text += " NOT DEFERRABLE"
        if constraint.initially is not None:
            text += " INITIALLY %s" % self.preparer.validate_sql_phrase(
                constraint.initially, FK_INITIALLY
            )
        return text

    def define_constraint_match(self, constraint):
        """Render the MATCH clause of a foreign key constraint."""
        text = ""
        if constraint.match is not None:
            text += " MATCH %s" % constraint.match
        return text

    def visit_computed_column(self, generated, **kw):
        """Render GENERATED ALWAYS AS (...) [STORED | VIRTUAL]."""
        text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
            generated.sqltext, include_table=False, literal_binds=True
        )
        if generated.persisted is True:
            text += " STORED"
        elif generated.persisted is False:
            text += " VIRTUAL"
        return text

    def visit_identity_column(self, identity, **kw):
        """Render GENERATED {ALWAYS | BY DEFAULT} AS IDENTITY plus
        parenthesized options."""
        text = "GENERATED %s AS IDENTITY" % (
            "ALWAYS" if identity.always else "BY DEFAULT",
        )
        options = self.get_identity_options(identity)
        if options:
            text += " (%s)" % options
        return text
7271
7272
class GenericTypeCompiler(TypeCompiler):
    """Produce ANSI-style type strings for the generic SQL types.

    Uppercase ``visit_`` methods render the DDL keyword for the
    corresponding uppercase type; lowercase ``visit_`` methods
    correspond to the "camelcase" generic types and delegate to an
    uppercase method.  Dialects override individual methods to
    customize rendering.
    """

    def visit_FLOAT(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
        return "FLOAT"

    def visit_DOUBLE(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
        return "DOUBLE"

    def visit_DOUBLE_PRECISION(
        self, type_: sqltypes.DOUBLE_PRECISION[Any], **kw: Any
    ) -> str:
        return "DOUBLE PRECISION"

    def visit_REAL(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
        return "REAL"

    def _render_precision_type(
        self, name: str, precision: Optional[int], scale: Optional[int]
    ) -> str:
        """Render ``NAME``, ``NAME(precision)``, or
        ``NAME(precision, scale)``; shared by NUMERIC and DECIMAL.
        """
        if precision is None:
            return name
        elif scale is None:
            return f"{name}({precision})"
        else:
            return f"{name}({precision}, {scale})"

    def visit_NUMERIC(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
        return self._render_precision_type(
            "NUMERIC", type_.precision, type_.scale
        )

    def visit_DECIMAL(self, type_: sqltypes.DECIMAL[Any], **kw: Any) -> str:
        return self._render_precision_type(
            "DECIMAL", type_.precision, type_.scale
        )

    def visit_INTEGER(self, type_: sqltypes.Integer, **kw: Any) -> str:
        return "INTEGER"

    def visit_SMALLINT(self, type_: sqltypes.SmallInteger, **kw: Any) -> str:
        return "SMALLINT"

    def visit_BIGINT(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
        return "BIGINT"

    def visit_TIMESTAMP(self, type_: sqltypes.TIMESTAMP, **kw: Any) -> str:
        return "TIMESTAMP"

    def visit_DATETIME(self, type_: sqltypes.DateTime, **kw: Any) -> str:
        return "DATETIME"

    def visit_DATE(self, type_: sqltypes.Date, **kw: Any) -> str:
        return "DATE"

    def visit_TIME(self, type_: sqltypes.Time, **kw: Any) -> str:
        return "TIME"

    def visit_CLOB(self, type_: sqltypes.CLOB, **kw: Any) -> str:
        return "CLOB"

    def visit_NCLOB(self, type_: sqltypes.Text, **kw: Any) -> str:
        return "NCLOB"

    def _render_string_type(
        self, name: str, length: Optional[int], collation: Optional[str]
    ) -> str:
        """Render a string type with optional length and COLLATE clause;
        a length of None or 0 renders no parenthesized length.
        """
        text = name
        if length:
            text += f"({length})"
        if collation:
            text += f' COLLATE "{collation}"'
        return text

    def visit_CHAR(self, type_: sqltypes.CHAR, **kw: Any) -> str:
        return self._render_string_type("CHAR", type_.length, type_.collation)

    def visit_NCHAR(self, type_: sqltypes.NCHAR, **kw: Any) -> str:
        return self._render_string_type("NCHAR", type_.length, type_.collation)

    def visit_VARCHAR(self, type_: sqltypes.String, **kw: Any) -> str:
        return self._render_string_type(
            "VARCHAR", type_.length, type_.collation
        )

    def visit_NVARCHAR(self, type_: sqltypes.NVARCHAR, **kw: Any) -> str:
        return self._render_string_type(
            "NVARCHAR", type_.length, type_.collation
        )

    def visit_TEXT(self, type_: sqltypes.Text, **kw: Any) -> str:
        return self._render_string_type("TEXT", type_.length, type_.collation)

    def visit_UUID(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
        return "UUID"

    def visit_BLOB(self, type_: sqltypes.LargeBinary, **kw: Any) -> str:
        return "BLOB"

    def visit_BINARY(self, type_: sqltypes.BINARY, **kw: Any) -> str:
        # a length of None or 0 renders no parenthesized length
        return "BINARY" + ("(%d)" % type_.length if type_.length else "")

    def visit_VARBINARY(self, type_: sqltypes.VARBINARY, **kw: Any) -> str:
        # a length of None or 0 renders no parenthesized length
        return "VARBINARY" + ("(%d)" % type_.length if type_.length else "")

    def visit_BOOLEAN(self, type_: sqltypes.Boolean, **kw: Any) -> str:
        return "BOOLEAN"

    def visit_uuid(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
        # without native UUID support, a UUID is stored as its 32-char
        # hex string in a CHAR(32)
        if not type_.native_uuid or not self.dialect.supports_native_uuid:
            return self._render_string_type("CHAR", length=32, collation=None)
        else:
            return self.visit_UUID(type_, **kw)

    def visit_large_binary(
        self, type_: sqltypes.LargeBinary, **kw: Any
    ) -> str:
        return self.visit_BLOB(type_, **kw)

    def visit_boolean(self, type_: sqltypes.Boolean, **kw: Any) -> str:
        return self.visit_BOOLEAN(type_, **kw)

    def visit_time(self, type_: sqltypes.Time, **kw: Any) -> str:
        return self.visit_TIME(type_, **kw)

    def visit_datetime(self, type_: sqltypes.DateTime, **kw: Any) -> str:
        return self.visit_DATETIME(type_, **kw)

    def visit_date(self, type_: sqltypes.Date, **kw: Any) -> str:
        return self.visit_DATE(type_, **kw)

    def visit_big_integer(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
        return self.visit_BIGINT(type_, **kw)

    def visit_small_integer(
        self, type_: sqltypes.SmallInteger, **kw: Any
    ) -> str:
        return self.visit_SMALLINT(type_, **kw)

    def visit_integer(self, type_: sqltypes.Integer, **kw: Any) -> str:
        return self.visit_INTEGER(type_, **kw)

    def visit_real(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
        return self.visit_REAL(type_, **kw)

    def visit_float(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
        return self.visit_FLOAT(type_, **kw)

    def visit_double(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
        return self.visit_DOUBLE(type_, **kw)

    def visit_numeric(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
        return self.visit_NUMERIC(type_, **kw)

    def visit_string(self, type_: sqltypes.String, **kw: Any) -> str:
        return self.visit_VARCHAR(type_, **kw)

    def visit_unicode(self, type_: sqltypes.Unicode, **kw: Any) -> str:
        return self.visit_VARCHAR(type_, **kw)

    def visit_text(self, type_: sqltypes.Text, **kw: Any) -> str:
        return self.visit_TEXT(type_, **kw)

    def visit_unicode_text(
        self, type_: sqltypes.UnicodeText, **kw: Any
    ) -> str:
        return self.visit_TEXT(type_, **kw)

    def visit_enum(self, type_: sqltypes.Enum, **kw: Any) -> str:
        return self.visit_VARCHAR(type_, **kw)

    def visit_null(self, type_, **kw):
        # NullType has no DDL representation; a Column was likely created
        # without a type
        raise exc.CompileError(
            "Can't generate DDL for %r; "
            "did you forget to specify a "
            "type on this Column?" % type_
        )

    def visit_type_decorator(
        self, type_: TypeDecorator[Any], **kw: Any
    ) -> str:
        # render the dialect-specific "impl" type of the decorator
        return self.process(type_.type_engine(self.dialect), **kw)

    def visit_user_defined(
        self, type_: UserDefinedType[Any], **kw: Any
    ) -> str:
        return type_.get_col_spec(**kw)
7460
7461
class StrSQLTypeCompiler(GenericTypeCompiler):
    """Type compiler for the string-only compiler; renders a best-effort
    name for any type object rather than failing on unknown types."""

    def process(self, type_, **kw):
        dispatch = getattr(type_, "_compiler_dispatch", None)
        if dispatch is None:
            # not a visitable type object; fall back to a generic render
            return self._visit_unknown(type_, **kw)
        return dispatch(self, **kw)

    def __getattr__(self, key):
        # any visit_XYZ method not otherwise defined resolves to the
        # generic renderer
        if not key.startswith("visit_"):
            raise AttributeError(key)
        return self._visit_unknown

    def _visit_unknown(self, type_, **kw):
        cls_name = type_.__class__.__name__
        # all-caps type class names (e.g. VARCHAR) already read as SQL
        if cls_name == cls_name.upper():
            return cls_name
        return repr(type_)

    def visit_null(self, type_, **kw):
        return "NULL"

    def visit_user_defined(self, type_, **kw):
        get_col_spec = getattr(type_, "get_col_spec", None)
        if get_col_spec is None:
            return repr(type_)
        return get_col_spec(**kw)
7493
7494
class _SchemaForObjectCallable(Protocol):
    """Callable protocol: given a schema item, return its effective
    schema name string (see ``IdentifierPreparer.schema_for_object``)."""

    def __call__(self, obj: Any, /) -> str: ...
7497
7498
class _BindNameForColProtocol(Protocol):
    """Callable protocol: given a column clause, return the bind
    parameter name to use for it."""

    def __call__(self, col: ColumnClause[Any]) -> str: ...
7501
7502
7503class IdentifierPreparer:
7504 """Handle quoting and case-folding of identifiers based on options."""
7505
7506 reserved_words = RESERVED_WORDS
7507
7508 legal_characters = LEGAL_CHARACTERS
7509
7510 illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
7511
7512 initial_quote: str
7513
7514 final_quote: str
7515
7516 _strings: MutableMapping[str, str]
7517
7518 schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema")
7519 """Return the .schema attribute for an object.
7520
7521 For the default IdentifierPreparer, the schema for an object is always
7522 the value of the ".schema" attribute. if the preparer is replaced
7523 with one that has a non-empty schema_translate_map, the value of the
7524 ".schema" attribute is rendered a symbol that will be converted to a
7525 real schema name from the mapping post-compile.
7526
7527 """
7528
7529 _includes_none_schema_translate: bool = False
7530
7531 def __init__(
7532 self,
7533 dialect: Dialect,
7534 initial_quote: str = '"',
7535 final_quote: Optional[str] = None,
7536 escape_quote: str = '"',
7537 quote_case_sensitive_collations: bool = True,
7538 omit_schema: bool = False,
7539 ):
7540 """Construct a new ``IdentifierPreparer`` object.
7541
7542 initial_quote
7543 Character that begins a delimited identifier.
7544
7545 final_quote
7546 Character that ends a delimited identifier. Defaults to
7547 `initial_quote`.
7548
7549 omit_schema
7550 Prevent prepending schema name. Useful for databases that do
7551 not support schemae.
7552 """
7553
7554 self.dialect = dialect
7555 self.initial_quote = initial_quote
7556 self.final_quote = final_quote or self.initial_quote
7557 self.escape_quote = escape_quote
7558 self.escape_to_quote = self.escape_quote * 2
7559 self.omit_schema = omit_schema
7560 self.quote_case_sensitive_collations = quote_case_sensitive_collations
7561 self._strings = {}
7562 self._double_percents = self.dialect.paramstyle in (
7563 "format",
7564 "pyformat",
7565 )
7566
    def _with_schema_translate(self, schema_translate_map):
        """Return a copy of this preparer that renders schema names as
        ``__[SCHEMA_<name>]`` placeholder tokens.

        The placeholders are replaced with real schema names at
        execution time by :meth:`._render_schema_translates`.
        """
        # shallow-copy this preparer rather than re-running __init__
        prep = self.__class__.__new__(self.__class__)
        prep.__dict__.update(self.__dict__)

        includes_none = None in schema_translate_map

        def symbol_getter(obj):
            # replaces the default attrgetter("schema") so that objects
            # participating in the schema map render a placeholder token
            # instead of their literal schema name
            name = obj.schema
            if obj._use_schema_map and (name is not None or includes_none):
                # square brackets would corrupt the placeholder syntax
                # parsed back out by _render_schema_translates
                if name is not None and ("[" in name or "]" in name):
                    raise exc.CompileError(
                        "Square bracket characters ([]) not supported "
                        "in schema translate name '%s'" % name
                    )
                # a None schema is encoded with the "_none" token
                return quoted_name(
                    "__[SCHEMA_%s]" % (name or "_none"), quote=False
                )
            else:
                return obj.schema

        prep.schema_for_object = symbol_getter
        prep._includes_none_schema_translate = includes_none
        return prep
7590
    def _render_schema_translates(
        self, statement: str, schema_translate_map: SchemaTranslateMapType
    ) -> str:
        """Replace ``__[SCHEMA_<name>]`` placeholder tokens in a compiled
        statement with quoted schema names from *schema_translate_map*.

        Raises ``InvalidRequestError`` if the presence of a ``None`` key
        in the map is inconsistent with the map used at compile time.
        """
        d = schema_translate_map
        if None in d:
            if not self._includes_none_schema_translate:
                raise exc.InvalidRequestError(
                    "schema translate map which previously did not have "
                    "`None` present as a key now has `None` present; compiled "
                    "statement may lack adequate placeholders. Please use "
                    "consistent keys in successive "
                    "schema_translate_map dictionaries."
                )

            # placeholders encode a None schema as the "_none" token
            d["_none"] = d[None]  # type: ignore[index]

        def replace(m):
            # group 2 is the <name> captured from __[SCHEMA_<name>]
            name = m.group(2)
            if name in d:
                effective_schema = d[name]
            else:
                if name in (None, "_none"):
                    raise exc.InvalidRequestError(
                        "schema translate map which previously had `None` "
                        "present as a key now no longer has it present; don't "
                        "know how to apply schema for compiled statement. "
                        "Please use consistent keys in successive "
                        "schema_translate_map dictionaries."
                    )
                # names absent from the map pass through unchanged
                effective_schema = name

            if not effective_schema:
                # blank/None translation targets fall back to the
                # dialect's default schema
                effective_schema = self.dialect.default_schema_name
                if not effective_schema:
                    # TODO: no coverage here
                    raise exc.CompileError(
                        "Dialect has no default schema name; can't "
                        "use None as dynamic schema target."
                    )
            return self.quote_schema(effective_schema)

        return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement)
7633
7634 def _escape_identifier(self, value: str) -> str:
7635 """Escape an identifier.
7636
7637 Subclasses should override this to provide database-dependent
7638 escaping behavior.
7639 """
7640
7641 value = value.replace(self.escape_quote, self.escape_to_quote)
7642 if self._double_percents:
7643 value = value.replace("%", "%%")
7644 return value
7645
    def _unescape_identifier(self, value: str) -> str:
        """Canonicalize an escaped identifier.

        Subclasses should override this to provide database-dependent
        unescaping behavior that reverses _escape_identifier.
        """

        # undo the quote-doubling applied by _escape_identifier
        return value.replace(self.escape_to_quote, self.escape_quote)
7654
7655 def validate_sql_phrase(self, element, reg):
7656 """keyword sequence filter.
7657
7658 a filter for elements that are intended to represent keyword sequences,
7659 such as "INITIALLY", "INITIALLY DEFERRED", etc. no special characters
7660 should be present.
7661
7662 """
7663
7664 if element is not None and not reg.match(element):
7665 raise exc.CompileError(
7666 "Unexpected SQL phrase: %r (matching against %r)"
7667 % (element, reg.pattern)
7668 )
7669 return element
7670
7671 def quote_identifier(self, value: str) -> str:
7672 """Quote an identifier.
7673
7674 Subclasses should override this to provide database-dependent
7675 quoting behavior.
7676 """
7677
7678 return (
7679 self.initial_quote
7680 + self._escape_identifier(value)
7681 + self.final_quote
7682 )
7683
7684 def _requires_quotes(self, value: str) -> bool:
7685 """Return True if the given identifier requires quoting."""
7686 lc_value = value.lower()
7687 return (
7688 lc_value in self.reserved_words
7689 or value[0] in self.illegal_initial_characters
7690 or not self.legal_characters.match(str(value))
7691 or (lc_value != value)
7692 )
7693
7694 def _requires_quotes_illegal_chars(self, value):
7695 """Return True if the given identifier requires quoting, but
7696 not taking case convention into account."""
7697 return not self.legal_characters.match(str(value))
7698
    def quote_schema(self, schema: str) -> str:
        """Conditionally quote a schema name.


        The name is quoted if it is a reserved word, contains quote-necessary
        characters, or is an instance of :class:`.quoted_name` which includes
        ``quote`` set to ``True``.

        Subclasses can override this to provide database-dependent
        quoting behavior for schema names.

        :param schema: string schema name
        """
        # by default, schema names follow the same quoting rules as any
        # other identifier
        return self.quote(schema)
7713
7714 def quote(self, ident: str) -> str:
7715 """Conditionally quote an identifier.
7716
7717 The identifier is quoted if it is a reserved word, contains
7718 quote-necessary characters, or is an instance of
7719 :class:`.quoted_name` which includes ``quote`` set to ``True``.
7720
7721 Subclasses can override this to provide database-dependent
7722 quoting behavior for identifier names.
7723
7724 :param ident: string identifier
7725 """
7726 force = getattr(ident, "quote", None)
7727
7728 if force is None:
7729 if ident in self._strings:
7730 return self._strings[ident]
7731 else:
7732 if self._requires_quotes(ident):
7733 self._strings[ident] = self.quote_identifier(ident)
7734 else:
7735 self._strings[ident] = ident
7736 return self._strings[ident]
7737 elif force:
7738 return self.quote_identifier(ident)
7739 else:
7740 return ident
7741
7742 def format_collation(self, collation_name):
7743 if self.quote_case_sensitive_collations:
7744 return self.quote(collation_name)
7745 else:
7746 return collation_name
7747
7748 def format_sequence(
7749 self, sequence: schema.Sequence, use_schema: bool = True
7750 ) -> str:
7751 name = self.quote(sequence.name)
7752
7753 effective_schema = self.schema_for_object(sequence)
7754
7755 if (
7756 not self.omit_schema
7757 and use_schema
7758 and effective_schema is not None
7759 ):
7760 name = self.quote_schema(effective_schema) + "." + name
7761 return name
7762
7763 def format_label(
7764 self, label: Label[Any], name: Optional[str] = None
7765 ) -> str:
7766 return self.quote(name or label.name)
7767
7768 def format_alias(
7769 self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None
7770 ) -> str:
7771 if name is None:
7772 assert alias is not None
7773 return self.quote(alias.name)
7774 else:
7775 return self.quote(name)
7776
7777 def format_savepoint(self, savepoint, name=None):
7778 # Running the savepoint name through quoting is unnecessary
7779 # for all known dialects. This is here to support potential
7780 # third party use cases
7781 ident = name or savepoint.ident
7782 if self._requires_quotes(ident):
7783 ident = self.quote_identifier(ident)
7784 return ident
7785
    @util.preload_module("sqlalchemy.sql.naming")
    def format_constraint(
        self, constraint: Union[Constraint, Index], _alembic_quote: bool = True
    ) -> Optional[str]:
        """Return the truncated, conditionally quoted name for
        *constraint*, or ``None`` if no name can be determined.

        :param _alembic_quote: when False, the truncated name is returned
         without the final quoting step.
        """
        naming = util.preloaded.sql_naming

        # _NONE_NAME indicates the name is to be derived from the
        # metadata's naming convention
        if constraint.name is _NONE_NAME:
            name = naming._constraint_name_for_table(
                constraint, constraint.table
            )

            if name is None:
                return None
        else:
            name = constraint.name

        assert name is not None
        # indexes may have a different maximum name length than other
        # constraints, so they take a separate truncation path
        if constraint.__visit_name__ == "index":
            return self.truncate_and_render_index_name(
                name, _alembic_quote=_alembic_quote
            )
        else:
            return self.truncate_and_render_constraint_name(
                name, _alembic_quote=_alembic_quote
            )
7811
7812 def truncate_and_render_index_name(
7813 self, name: str, _alembic_quote: bool = True
7814 ) -> str:
7815 # calculate these at format time so that ad-hoc changes
7816 # to dialect.max_identifier_length etc. can be reflected
7817 # as IdentifierPreparer is long lived
7818 max_ = (
7819 self.dialect.max_index_name_length
7820 or self.dialect.max_identifier_length
7821 )
7822 return self._truncate_and_render_maxlen_name(
7823 name, max_, _alembic_quote
7824 )
7825
7826 def truncate_and_render_constraint_name(
7827 self, name: str, _alembic_quote: bool = True
7828 ) -> str:
7829 # calculate these at format time so that ad-hoc changes
7830 # to dialect.max_identifier_length etc. can be reflected
7831 # as IdentifierPreparer is long lived
7832 max_ = (
7833 self.dialect.max_constraint_name_length
7834 or self.dialect.max_identifier_length
7835 )
7836 return self._truncate_and_render_maxlen_name(
7837 name, max_, _alembic_quote
7838 )
7839
7840 def _truncate_and_render_maxlen_name(
7841 self, name: str, max_: int, _alembic_quote: bool
7842 ) -> str:
7843 if isinstance(name, elements._truncated_label):
7844 if len(name) > max_:
7845 name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
7846 else:
7847 self.dialect.validate_identifier(name)
7848
7849 if not _alembic_quote:
7850 return name
7851 else:
7852 return self.quote(name)
7853
7854 def format_index(self, index: Index) -> str:
7855 name = self.format_constraint(index)
7856 assert name is not None
7857 return name
7858
7859 def format_table(
7860 self,
7861 table: FromClause,
7862 use_schema: bool = True,
7863 name: Optional[str] = None,
7864 ) -> str:
7865 """Prepare a quoted table and schema name."""
7866 if name is None:
7867 if TYPE_CHECKING:
7868 assert isinstance(table, NamedFromClause)
7869 name = table.name
7870
7871 result = self.quote(name)
7872
7873 effective_schema = self.schema_for_object(table)
7874
7875 if not self.omit_schema and use_schema and effective_schema:
7876 result = self.quote_schema(effective_schema) + "." + result
7877 return result
7878
    def format_schema(self, name):
        """Prepare a quoted schema name."""

        # schema names follow the standard conditional-quoting rules
        return self.quote(name)
7883
7884 def format_label_name(
7885 self,
7886 name,
7887 anon_map=None,
7888 ):
7889 """Prepare a quoted column name."""
7890
7891 if anon_map is not None and isinstance(
7892 name, elements._truncated_label
7893 ):
7894 name = name.apply_map(anon_map)
7895
7896 return self.quote(name)
7897
7898 def format_column(
7899 self,
7900 column: ColumnElement[Any],
7901 use_table: bool = False,
7902 name: Optional[str] = None,
7903 table_name: Optional[str] = None,
7904 use_schema: bool = False,
7905 anon_map: Optional[Mapping[str, Any]] = None,
7906 ) -> str:
7907 """Prepare a quoted column name."""
7908
7909 if name is None:
7910 name = column.name
7911 assert name is not None
7912
7913 if anon_map is not None and isinstance(
7914 name, elements._truncated_label
7915 ):
7916 name = name.apply_map(anon_map)
7917
7918 if not getattr(column, "is_literal", False):
7919 if use_table:
7920 return (
7921 self.format_table(
7922 column.table, use_schema=use_schema, name=table_name
7923 )
7924 + "."
7925 + self.quote(name)
7926 )
7927 else:
7928 return self.quote(name)
7929 else:
7930 # literal textual elements get stuck into ColumnClause a lot,
7931 # which shouldn't get quoted
7932
7933 if use_table:
7934 return (
7935 self.format_table(
7936 column.table, use_schema=use_schema, name=table_name
7937 )
7938 + "."
7939 + name
7940 )
7941 else:
7942 return name
7943
7944 def format_table_seq(self, table, use_schema=True):
7945 """Format table name and schema as a tuple."""
7946
7947 # Dialects with more levels in their fully qualified references
7948 # ('database', 'owner', etc.) could override this and return
7949 # a longer sequence.
7950
7951 effective_schema = self.schema_for_object(table)
7952
7953 if not self.omit_schema and use_schema and effective_schema:
7954 return (
7955 self.quote_schema(effective_schema),
7956 self.format_table(table, use_schema=False),
7957 )
7958 else:
7959 return (self.format_table(table, use_schema=False),)
7960
    @util.memoized_property
    def _r_identifiers(self):
        # Pattern matching the dot-separated components of a qualified
        # identifier string: each component is either a quoted identifier
        # (initial..final quotes, with escaped final quotes allowed
        # inside) or a run of non-dot characters.  Group 1 captures the
        # quoted content, group 2 the unquoted form; exactly one of the
        # two is non-empty per component (see unformat_identifiers).
        initial, final, escaped_final = (
            re.escape(s)
            for s in (
                self.initial_quote,
                self.final_quote,
                self._escape_identifier(self.final_quote),
            )
        )
        r = re.compile(
            r"(?:"
            r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s"
            r"|([^\.]+))(?=\.|$))+"
            % {"initial": initial, "final": final, "escaped": escaped_final}
        )
        return r
7978
7979 def unformat_identifiers(self, identifiers: str) -> Sequence[str]:
7980 """Unpack 'schema.table.column'-like strings into components."""
7981
7982 r = self._r_identifiers
7983 return [
7984 self._unescape_identifier(i)
7985 for i in [a or b for a, b in r.findall(identifiers)]
7986 ]