1# sql/compiler.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Base SQL and DDL compiler implementations.
10
11Classes provided include:
12
13:class:`.compiler.SQLCompiler` - renders SQL
14strings
15
16:class:`.compiler.DDLCompiler` - renders DDL
17(data definition language) strings
18
19:class:`.compiler.GenericTypeCompiler` - renders
20type specification strings.
21
22To generate user-defined SQL strings, see
23:doc:`/ext/compiler`.
24
25"""
26from __future__ import annotations
27
28import collections
29import collections.abc as collections_abc
30import contextlib
31from enum import IntEnum
32import functools
33import itertools
34import operator
35import re
36from time import perf_counter
37import typing
38from typing import Any
39from typing import Callable
40from typing import cast
41from typing import ClassVar
42from typing import Dict
43from typing import FrozenSet
44from typing import Iterable
45from typing import Iterator
46from typing import List
47from typing import Mapping
48from typing import MutableMapping
49from typing import NamedTuple
50from typing import NoReturn
51from typing import Optional
52from typing import Pattern
53from typing import Protocol
54from typing import Sequence
55from typing import Set
56from typing import Tuple
57from typing import Type
58from typing import TYPE_CHECKING
59from typing import TypedDict
60from typing import Union
61
62from . import base
63from . import coercions
64from . import crud
65from . import elements
66from . import functions
67from . import operators
68from . import roles
69from . import schema
70from . import selectable
71from . import sqltypes
72from . import util as sql_util
73from ._typing import is_column_element
74from ._typing import is_dml
75from .base import _de_clone
76from .base import _from_objects
77from .base import _NONE_NAME
78from .base import _SentinelDefaultCharacterization
79from .base import NO_ARG
80from .elements import quoted_name
81from .sqltypes import TupleType
82from .visitors import prefix_anon_map
83from .. import exc
84from .. import util
85from ..util import FastIntFlag
86from ..util.typing import Literal
87from ..util.typing import Self
88from ..util.typing import TupleAny
89from ..util.typing import Unpack
90
91if typing.TYPE_CHECKING:
92 from .annotation import _AnnotationDict
93 from .base import _AmbiguousTableNameMap
94 from .base import CompileState
95 from .base import Executable
96 from .cache_key import CacheKey
97 from .ddl import ExecutableDDLElement
98 from .dml import Delete
99 from .dml import Insert
100 from .dml import Update
101 from .dml import UpdateBase
102 from .dml import UpdateDMLState
103 from .dml import ValuesBase
104 from .elements import _truncated_label
105 from .elements import BinaryExpression
106 from .elements import BindParameter
107 from .elements import ClauseElement
108 from .elements import ColumnClause
109 from .elements import ColumnElement
110 from .elements import False_
111 from .elements import Label
112 from .elements import Null
113 from .elements import True_
114 from .functions import Function
115 from .schema import Column
116 from .schema import Constraint
117 from .schema import ForeignKeyConstraint
118 from .schema import Index
119 from .schema import PrimaryKeyConstraint
120 from .schema import Table
121 from .schema import UniqueConstraint
122 from .selectable import _ColumnsClauseElement
123 from .selectable import AliasedReturnsRows
124 from .selectable import CompoundSelectState
125 from .selectable import CTE
126 from .selectable import FromClause
127 from .selectable import NamedFromClause
128 from .selectable import ReturnsRows
129 from .selectable import Select
130 from .selectable import SelectState
131 from .type_api import _BindProcessorType
132 from .type_api import TypeDecorator
133 from .type_api import TypeEngine
134 from .type_api import UserDefinedType
135 from .visitors import Visitable
136 from ..engine.cursor import CursorResultMetaData
137 from ..engine.interfaces import _CoreSingleExecuteParams
138 from ..engine.interfaces import _DBAPIAnyExecuteParams
139 from ..engine.interfaces import _DBAPIMultiExecuteParams
140 from ..engine.interfaces import _DBAPISingleExecuteParams
141 from ..engine.interfaces import _ExecuteOptions
142 from ..engine.interfaces import _GenericSetInputSizesType
143 from ..engine.interfaces import _MutableCoreSingleExecuteParams
144 from ..engine.interfaces import Dialect
145 from ..engine.interfaces import SchemaTranslateMapType
146
147
# type alias: mapping of FROM clause element to its hint text
_FromHintsType = Dict["FromClause", str]
149
# Default set of (lowercase) SQL reserved words; dialects may supply
# their own sets in place of this one.
RESERVED_WORDS = set(
    """
    all analyse analyze and any array as asc asymmetric authorization
    between binary both case cast check collate column constraint create
    cross current_date current_role current_time current_timestamp
    current_user default deferrable desc distinct do else end except
    false for foreign freeze from full grant group having ilike in
    initially inner intersect into is isnull join leading left like
    limit localtime localtimestamp natural new not notnull null off
    offset old on only or order outer overlaps placing primary
    references right select session_user set similar some symmetric
    table then to trailing true union unique user using verbose when
    where
    """.split()
)
246
# identifiers composed entirely of these characters (case-insensitive)
# are "legal" as-is
LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I)
# as above, but additionally permitting embedded spaces
LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I)
# characters which may not begin an unquoted identifier
ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"])

# ON DELETE and ON UPDATE accept the same set of standard referential
# actions, so both matchers share one pattern literal.
_FK_ACTION = r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$"
FK_ON_DELETE = re.compile(_FK_ACTION, re.I)
FK_ON_UPDATE = re.compile(_FK_ACTION, re.I)
FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I)
# locates ":name" bound parameter markers in SQL text; the lookbehind /
# lookahead skip doubled colons (e.g. "::" casts) and
# backslash-escaped colons
BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
# locates backslash-escaped ":name" sequences
BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)
260
# template for the "pyformat" paramstyle, producing e.g. "%(some_name)s"
_pyformat_template = "%%(%(name)s)s"

# bound-parameter rendering templates keyed by DBAPI paramstyle name
# (see PEP 249).  "%(name)s" is substituted with the bind name;
# NOTE(review): the literal "[_POSITION]" token appears to be a
# placeholder later replaced with the parameter's ordinal position for
# the numeric styles -- confirm against bindparam rendering logic.
BIND_TEMPLATES = {
    "pyformat": _pyformat_template,
    "qmark": "?",
    "format": "%%s",
    "numeric": ":[_POSITION]",
    "numeric_dollar": "$[_POSITION]",
    "named": ":%(name)s",
}
270
271
# default SQL rendering for operator functions from the ``operators``
# module.  Binary operator strings carry surrounding spaces; unary
# operators carry a trailing space; modifiers carry a leading space,
# matching how each attaches to its operand(s).
OPERATORS = {
    # binary
    operators.and_: " AND ",
    operators.or_: " OR ",
    operators.add: " + ",
    operators.mul: " * ",
    operators.sub: " - ",
    operators.mod: " % ",
    operators.neg: "-",
    operators.lt: " < ",
    operators.le: " <= ",
    operators.ne: " != ",
    operators.gt: " > ",
    operators.ge: " >= ",
    operators.eq: " = ",
    operators.is_distinct_from: " IS DISTINCT FROM ",
    operators.is_not_distinct_from: " IS NOT DISTINCT FROM ",
    operators.concat_op: " || ",
    operators.match_op: " MATCH ",
    operators.not_match_op: " NOT MATCH ",
    operators.in_op: " IN ",
    operators.not_in_op: " NOT IN ",
    operators.comma_op: ", ",
    operators.from_: " FROM ",
    operators.as_: " AS ",
    operators.is_: " IS ",
    operators.is_not: " IS NOT ",
    operators.collate: " COLLATE ",
    # unary
    operators.exists: "EXISTS ",
    operators.distinct_op: "DISTINCT ",
    operators.inv: "NOT ",
    operators.any_op: "ANY ",
    operators.all_op: "ALL ",
    # modifiers
    operators.desc_op: " DESC",
    operators.asc_op: " ASC",
    operators.nulls_first_op: " NULLS FIRST",
    operators.nulls_last_op: " NULLS LAST",
    # bitwise
    operators.bitwise_xor_op: " ^ ",
    operators.bitwise_or_op: " | ",
    operators.bitwise_and_op: " & ",
    operators.bitwise_not_op: "~",
    operators.bitwise_lshift_op: " << ",
    operators.bitwise_rshift_op: " >> ",
}
319
# default rendered SQL names for known generic function classes from
# the ``functions`` module
FUNCTIONS: Dict[Type[Function[Any]], str] = {
    functions.coalesce: "coalesce",
    functions.current_date: "CURRENT_DATE",
    functions.current_time: "CURRENT_TIME",
    functions.current_timestamp: "CURRENT_TIMESTAMP",
    functions.current_user: "CURRENT_USER",
    functions.localtime: "LOCALTIME",
    functions.localtimestamp: "LOCALTIMESTAMP",
    functions.random: "random",
    functions.sysdate: "sysdate",
    functions.session_user: "SESSION_USER",
    functions.user: "USER",
    functions.cube: "CUBE",
    functions.rollup: "ROLLUP",
    functions.grouping_sets: "GROUPING SETS",
}
336
337
# identity mapping of the date/time field names usable with EXTRACT;
# the default compiler renders each field name unchanged
EXTRACT_MAP = {
    field: field
    for field in (
        "month",
        "day",
        "year",
        "second",
        "hour",
        "doy",
        "minute",
        "quarter",
        "dow",
        "week",
        "epoch",
        "milliseconds",
        "microseconds",
        "timezone_hour",
        "timezone_minute",
    )
}
355
# keywords rendered between the SELECTs of each compound
# (set-operation) variant
COMPOUND_KEYWORDS = {
    selectable._CompoundSelectKeyword.UNION: "UNION",
    selectable._CompoundSelectKeyword.UNION_ALL: "UNION ALL",
    selectable._CompoundSelectKeyword.EXCEPT: "EXCEPT",
    selectable._CompoundSelectKeyword.EXCEPT_ALL: "EXCEPT ALL",
    selectable._CompoundSelectKeyword.INTERSECT: "INTERSECT",
    selectable._CompoundSelectKeyword.INTERSECT_ALL: "INTERSECT ALL",
}
364
365
class ResultColumnsEntry(NamedTuple):
    """Tracks a column expression that is expected to be represented
    in the result rows for this statement.

    This normally refers to the columns clause of a SELECT statement
    but may also refer to a RETURNING clause, as well as for dialect-specific
    emulations.

    Entries are additionally accessed positionally via the ``RM_``
    integer constants defined later in this module.

    """

    keyname: str
    """string name that's expected in cursor.description"""

    name: str
    """column name, may be labeled"""

    objects: Tuple[Any, ...]
    """sequence of objects that should be able to locate this column
    in a RowMapping. This is typically string names and aliases
    as well as Column objects.

    """

    type: TypeEngine[Any]
    """Datatype to be associated with this column. This is where
    the "result processing" logic directly links the compiled statement
    to the rows that come back from the cursor.

    """
395
396
class _ResultMapAppender(Protocol):
    """callable protocol for functions which append an entry to a
    compiler's result map; the arguments mirror the fields of
    :class:`.ResultColumnsEntry`."""

    def __call__(
        self,
        keyname: str,
        name: str,
        objects: Sequence[Any],
        type_: TypeEngine[Any],
    ) -> None: ...
405
406
# integer indexes into ResultColumnsEntry used by cursor.py.
# some profiling showed integer access faster than named tuple
RM_RENDERED_NAME: Literal[0] = 0  # ResultColumnsEntry.keyname
RM_NAME: Literal[1] = 1  # ResultColumnsEntry.name
RM_OBJECTS: Literal[2] = 2  # ResultColumnsEntry.objects
RM_TYPE: Literal[3] = 3  # ResultColumnsEntry.type
413
414
class _BaseCompilerStackEntry(TypedDict):
    """required keys present in every compiler stack entry; see
    :class:`._CompilerStackEntry` for the optional keys."""

    # NOTE(review): FROM elements contributed when this selectable is
    # used as a FROM -- confirm against SQLCompiler usage
    asfrom_froms: Set[FromClause]
    # NOTE(review): FROM elements available for correlation -- confirm
    # against SQLCompiler usage
    correlate_froms: Set[FromClause]
    # the statement / selectable this stack entry corresponds to
    selectable: ReturnsRows
419
420
class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
    """optional keys which may additionally be present in a compiler
    stack entry (``total=False``)."""

    compile_state: CompileState
    need_result_map_for_nested: bool
    need_result_map_for_compound: bool
    select_0: ReturnsRows
    # NOTE(review): presumably the SELECT acting as the source of an
    # INSERT..FROM SELECT -- confirm against INSERT compilation
    insert_from_select: Select[Unpack[TupleAny]]
427
428
class ExpandedState(NamedTuple):
    """represents state to use when producing "expanded" and
    "post compile" bound parameters for a statement.

    "expanded" parameters are parameters that are generated at
    statement execution time to suit a number of parameters passed, the most
    prominent example being the individual elements inside of an IN
    expression.

    "post compile" parameters are parameters where the SQL literal value
    will be rendered into the SQL statement at execution time, rather than
    being passed as separate parameters to the driver.

    To create an :class:`.ExpandedState` instance, use the
    :meth:`.SQLCompiler.construct_expanded_state` method on any
    :class:`.SQLCompiler` instance.

    """

    statement: str
    """String SQL statement with parameters fully expanded"""

    parameters: _CoreSingleExecuteParams
    """Parameter dictionary with parameters fully expanded.

    For a statement that uses named parameters, this dictionary will map
    exactly to the names in the statement.  For a statement that uses
    positional parameters, the :attr:`.ExpandedState.positional_parameters`
    will yield a tuple with the positional parameter set.

    """

    processors: Mapping[str, _BindProcessorType[Any]]
    """mapping of bound value processors"""

    positiontup: Optional[Sequence[str]]
    """Sequence of string names indicating the order of positional
    parameters"""

    parameter_expansion: Mapping[str, List[str]]
    """Mapping representing the intermediary link from original parameter
    name to list of "expanded" parameter names, for those parameters that
    were expanded."""

    @property
    def positional_parameters(self) -> Tuple[Any, ...]:
        """Tuple of positional parameters, for statements that were compiled
        using a positional paramstyle.

        :raises: :class:`.exc.InvalidRequestError` if this statement did
         not use a positional paramstyle (i.e. ``positiontup`` is ``None``).

        """
        if self.positiontup is None:
            raise exc.InvalidRequestError(
                "statement does not use a positional paramstyle"
            )
        return tuple(self.parameters[key] for key in self.positiontup)

    @property
    def additional_parameters(self) -> _CoreSingleExecuteParams:
        """synonym for :attr:`.ExpandedState.parameters`."""
        return self.parameters
488
489
class _InsertManyValues(NamedTuple):
    """represents state to use for executing an "insertmanyvalues" statement.

    The primary consumers of this object are the
    :meth:`.SQLCompiler._deliver_insertmanyvalues_batches` and
    :meth:`.DefaultDialect._deliver_insertmanyvalues_batches` methods.

    .. versionadded:: 2.0

    """

    is_default_expr: bool
    """if True, the statement is of the form
    ``INSERT INTO TABLE DEFAULT VALUES``, and can't be rewritten as a "batch"

    """

    single_values_expr: str
    """The rendered "values" clause of the INSERT statement.

    This is typically the parenthesized section e.g. "(?, ?, ?)" or similar.
    The insertmanyvalues logic uses this string as a search and replace
    target.

    """

    insert_crud_params: List[crud._CrudParamElementStr]
    """List of Column / bind names etc. used while rewriting the statement"""

    num_positional_params_counted: int
    """the number of bound parameters in a single-row statement.

    This count may be larger or smaller than the actual number of columns
    targeted in the INSERT, as it accommodates for SQL expressions
    in the values list that may have zero or more parameters embedded
    within them.

    This count is part of what's used to organize rewritten parameter lists
    when batching.

    """

    sort_by_parameter_order: bool = False
    """if deterministic ordering of returned rows relative to the
    incoming parameter sets (i.e. ``sort_by_parameter_order``) was
    requested on the insert.

    All of the attributes following this will only be used if this is True.

    """

    includes_upsert_behaviors: bool = False
    """if True, we have to accommodate for upsert behaviors.

    This will in some cases downgrade "insertmanyvalues" that requests
    deterministic ordering.

    """

    sentinel_columns: Optional[Sequence[Column[Any]]] = None
    """List of sentinel columns that were located.

    This list is only here if the INSERT asked for
    sort_by_parameter_order=True,
    and dialect-appropriate sentinel columns were located.

    .. versionadded:: 2.0.10

    """

    num_sentinel_columns: int = 0
    """how many sentinel columns are in the above list, if any.

    This is the same as
    ``len(sentinel_columns) if sentinel_columns is not None else 0``

    """

    sentinel_param_keys: Optional[Sequence[str]] = None
    """parameter str keys in each param dictionary / tuple
    that would link to the client side "sentinel" values for that row, which
    we can use to match up parameter sets to result rows.

    This is only present if sentinel_columns is present and the INSERT
    statement actually refers to client side values for these sentinel
    columns.

    .. versionadded:: 2.0.10

    .. versionchanged:: 2.0.29 - the sequence is now string dictionary keys
       only, used against the "compiled parameters" collection before
       the parameters were converted by bound parameter processors

    """

    implicit_sentinel: bool = False
    """if True, we have exactly one sentinel column and it uses a server side
    value, currently has to generate an incrementing integer value.

    The dialect in question would have asserted that it supports receiving
    these values back and sorting on that value as a means of guaranteeing
    correlation with the incoming parameter list.

    .. versionadded:: 2.0.10

    """

    embed_values_counter: bool = False
    """Whether to embed an incrementing integer counter in each parameter
    set within the VALUES clause as parameters are batched over.

    This is only used for a specific INSERT..SELECT..VALUES..RETURNING syntax
    where a subquery is used to produce value tuples. Current support
    includes PostgreSQL, Microsoft SQL Server.

    .. versionadded:: 2.0.10

    """
607
608
class _InsertManyValuesBatch(NamedTuple):
    """represents an individual batch SQL statement for insertmanyvalues.

    This is passed through the
    :meth:`.SQLCompiler._deliver_insertmanyvalues_batches` and
    :meth:`.DefaultDialect._deliver_insertmanyvalues_batches` methods out
    to the :class:`.Connection` within the
    :meth:`.Connection._exec_insertmany_context` method.

    .. versionadded:: 2.0.10

    """

    # SQL string for this batch, with the single-row VALUES expression
    # replaced (see _InsertManyValues.single_values_expr)
    replaced_statement: str
    # parameter structure matching the rewritten statement
    replaced_parameters: _DBAPIAnyExecuteParams
    # NOTE(review): presumably setinputsizes data for dialects using
    # that feature -- confirm against DefaultDialect
    processed_setinputsizes: Optional[_GenericSetInputSizesType]
    # the original per-row parameter sets represented by this batch
    batch: Sequence[_DBAPISingleExecuteParams]
    # per-row sentinel value tuples used to correlate result rows
    # back to parameter sets
    sentinel_values: Sequence[Tuple[Any, ...]]
    # number of parameter sets in this batch
    current_batch_size: int
    # ordinal of this batch; total_batches is the overall count
    batchnum: int
    total_batches: int
    # True if rows for this batch are deterministically ordered
    rows_sorted: bool
    # True if the insertmanyvalues operation was "downgraded" for this
    # statement (see _InsertManyValues.includes_upsert_behaviors)
    is_downgraded: bool
632
633
class InsertmanyvaluesSentinelOpts(FastIntFlag):
    """bitflag enum indicating styles of PK defaults
    which can work as implicit sentinel columns

    """

    NOT_SUPPORTED = 1
    AUTOINCREMENT = 2
    IDENTITY = 4
    SEQUENCE = 8

    # any of the three server-generated auto-numbering styles
    ANY_AUTOINCREMENT = AUTOINCREMENT | IDENTITY | SEQUENCE
    _SUPPORTED_OR_NOT = NOT_SUPPORTED | ANY_AUTOINCREMENT

    USE_INSERT_FROM_SELECT = 16
    # NOTE: bit value 32 is unused here
    RENDER_SELECT_COL_CASTS = 64
650
651
class CompilerState(IntEnum):
    """enumerates the possible states of a :class:`.Compiled` object,
    as reported by its ``state`` attribute."""

    COMPILING = 0
    """statement is present, compilation phase in progress"""

    STRING_APPLIED = 1
    """statement is present, string form of the statement has been applied.

    Additional processors by subclasses may still be pending.

    """

    NO_STATEMENT = 2
    """compiler does not have a statement to compile, is used
    for method access"""
666
667
class Linting(IntEnum):
    """represent preferences for the 'SQL linting' feature.

    this feature currently includes support for flagging cartesian products
    in SQL statements.

    Members act as combinable bit values; ``FROM_LINTING`` is the OR of
    ``COLLECT_CARTESIAN_PRODUCTS`` and ``WARN_LINTING``.

    """

    NO_LINTING = 0
    "Disable all linting."

    COLLECT_CARTESIAN_PRODUCTS = 1
    """Collect data on FROMs and cartesian products and gather into
    'self.from_linter'"""

    WARN_LINTING = 2
    "Emit warnings for linters that find problems"

    FROM_LINTING = COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING
    """Warn for cartesian products; combines COLLECT_CARTESIAN_PRODUCTS
    and WARN_LINTING"""
689
690
# module-level shortcuts for the four Linting members, unpacked in
# definition order
NO_LINTING, COLLECT_CARTESIAN_PRODUCTS, WARN_LINTING, FROM_LINTING = tuple(
    Linting
)
694
695
class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
    """represents current state for the "cartesian product" detection
    feature.

    ``froms`` is a mapping of FROM element to its display name; ``edges``
    is a set of 2-tuples of FROM elements joined by some ON criteria.

    """

    def lint(self, start=None):
        """Detect FROM elements not linked to the rest via ``edges``.

        Traverses the edge graph from ``start`` (or an arbitrary element
        if not given) and returns a tuple
        ``(unreached_froms, start_element)``, or ``(None, None)`` if
        every FROM element was reached (or there are no FROM elements).

        """
        froms = self.froms
        if not froms:
            # no FROM elements present; nothing to lint
            return None, None

        edges = set(self.edges)
        the_rest = set(froms)

        if start is not None:
            start_with = start
            the_rest.remove(start_with)
        else:
            start_with = the_rest.pop()

        stack = collections.deque([start_with])

        while stack and the_rest:
            node = stack.popleft()
            the_rest.discard(node)

            # comparison of nodes in edges here is based on hash equality, as
            # there are "annotated" elements that match the non-annotated ones.
            # to remove the need for in-python hash() calls, use native
            # containment routines (e.g. "node in edge", "edge.index(node)")
            to_remove = {edge for edge in edges if node in edge}

            # appendleft the node in each edge that is not
            # the one that matched.
            stack.extendleft(edge[not edge.index(node)] for edge in to_remove)
            edges.difference_update(to_remove)

        # FROMS left over? boom
        if the_rest:
            return the_rest, start_with
        else:
            return None, None

    def warn(self, stmt_type="SELECT"):
        """Run :meth:`.lint` and emit a warning if unreachable FROM
        elements (i.e. a cartesian product) were detected."""
        the_rest, start_with = self.lint()

        # FROMS left over? boom
        # (note: previously this re-tested "the_rest" under a second
        # always-true conditional; the redundant check is removed)
        if the_rest:
            template = (
                "{stmt_type} statement has a cartesian product between "
                "FROM element(s) {froms} and "
                'FROM element "{start}". Apply join condition(s) '
                "between each element to resolve."
            )
            froms_str = ", ".join(
                f'"{self.froms[from_]}"' for from_ in the_rest
            )
            message = template.format(
                stmt_type=stmt_type,
                froms=froms_str,
                start=self.froms[start_with],
            )

            util.warn(message)
760
761
class Compiled:
    """Represent a compiled SQL or DDL expression.

    The ``__str__`` method of the ``Compiled`` object should produce
    the actual text of the statement.  ``Compiled`` objects are
    specific to their underlying database dialect, and also may
    or may not be specific to the columns referenced within a
    particular set of bind parameters.  In no case should the
    ``Compiled`` object be dependent on the actual values of those
    bind parameters, even though it may reference those values as
    defaults.
    """

    statement: Optional[ClauseElement] = None
    "The statement to compile."
    string: str = ""
    "The string representation of the ``statement``"

    state: CompilerState
    """description of the compiler's state"""

    # subtype flags; SQLCompiler sets is_sql=True (DDLCompiler
    # presumably sets is_ddl -- outside this view)
    is_sql = False
    is_ddl = False

    _cached_metadata: Optional[CursorResultMetaData] = None

    _result_columns: Optional[List[ResultColumnsEntry]] = None

    schema_translate_map: Optional[SchemaTranslateMapType] = None

    execution_options: _ExecuteOptions = util.EMPTY_DICT
    """
    Execution options propagated from the statement. In some cases,
    sub-elements of the statement can modify these.
    """

    preparer: IdentifierPreparer

    _annotations: _AnnotationDict = util.EMPTY_DICT

    compile_state: Optional[CompileState] = None
    """Optional :class:`.CompileState` object that maintains additional
    state used by the compiler.

    Major executable objects such as :class:`_expression.Insert`,
    :class:`_expression.Update`, :class:`_expression.Delete`,
    :class:`_expression.Select` will generate this
    state when compiled in order to calculate additional information about the
    object. For the top level object that is to be executed, the state can be
    stored here where it can also have applicability towards result set
    processing.

    .. versionadded:: 1.4

    """

    dml_compile_state: Optional[CompileState] = None
    """Optional :class:`.CompileState` assigned at the same point that
    .isinsert, .isupdate, or .isdelete is assigned.

    This will normally be the same object as .compile_state, with the
    exception of cases like the :class:`.ORMFromStatementCompileState`
    object.

    .. versionadded:: 1.4.40

    """

    cache_key: Optional[CacheKey] = None
    """The :class:`.CacheKey` that was generated ahead of creating this
    :class:`.Compiled` object.

    This is used for routines that need access to the original
    :class:`.CacheKey` instance generated when the :class:`.Compiled`
    instance was first cached, typically in order to reconcile
    the original list of :class:`.BindParameter` objects with a
    per-statement list that's generated on each call.

    """

    _gen_time: float
    """Generation time of this :class:`.Compiled`, used for reporting
    cache stats."""

    def __init__(
        self,
        dialect: Dialect,
        statement: Optional[ClauseElement],
        schema_translate_map: Optional[SchemaTranslateMapType] = None,
        render_schema_translate: bool = False,
        compile_kwargs: Mapping[str, Any] = util.immutabledict(),
    ):
        """Construct a new :class:`.Compiled` object.

        :param dialect: :class:`.Dialect` to compile against.

        :param statement: :class:`_expression.ClauseElement` to be compiled.
         May be ``None``, in which case this :class:`.Compiled` exists for
         method access only.

        :param schema_translate_map: dictionary of schema names to be
         translated when forming the resultant SQL

         .. seealso::

            :ref:`schema_translating`

        :param render_schema_translate: if True, the schema translate map
         is applied to the completed SQL string immediately via the
         identifier preparer.

        :param compile_kwargs: additional kwargs that will be
         passed to the initial call to :meth:`.Compiled.process`.


        """
        self.dialect = dialect
        self.preparer = self.dialect.identifier_preparer
        if schema_translate_map:
            self.schema_translate_map = schema_translate_map
            # replace the preparer with one that carries the
            # schema translate map
            self.preparer = self.preparer._with_schema_translate(
                schema_translate_map
            )

        # when statement is None, this Compiled is used for method
        # access only (state NO_STATEMENT)
        if statement is not None:
            self.state = CompilerState.COMPILING
            self.statement = statement
            self.can_execute = statement.supports_execution
            self._annotations = statement._annotations
            if self.can_execute:
                if TYPE_CHECKING:
                    assert isinstance(statement, Executable)
                self.execution_options = statement._execution_options
            # the compilation pass; produces the string form
            self.string = self.process(self.statement, **compile_kwargs)

            if render_schema_translate:
                assert schema_translate_map is not None
                self.string = self.preparer._render_schema_translates(
                    self.string, schema_translate_map
                )

            self.state = CompilerState.STRING_APPLIED
        else:
            self.state = CompilerState.NO_STATEMENT

        # generation timestamp, used for cache stats reporting
        self._gen_time = perf_counter()

    def __init_subclass__(cls) -> None:
        # hook allowing subclasses to run one-time class-level setup
        cls._init_compiler_cls()
        return super().__init_subclass__()

    @classmethod
    def _init_compiler_cls(cls):
        pass

    def _execute_on_connection(
        self, connection, distilled_params, execution_options
    ):
        # entrypoint used by Connection to execute this compiled object;
        # non-executable statements raise
        if self.can_execute:
            return connection._execute_compiled(
                self, distilled_params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self.statement)

    def visit_unsupported_compilation(self, element, err, **kw):
        # NOTE(review): presumably the fallback dispatch target when an
        # element has no applicable visit method -- chains the original
        # dispatch error
        raise exc.UnsupportedCompilationError(self, type(element)) from err

    @property
    def sql_compiler(self) -> SQLCompiler:
        """Return a Compiled that is capable of processing SQL expressions.

        If this compiler is one, it would likely just return 'self'.

        """

        raise NotImplementedError()

    def process(self, obj: Visitable, **kwargs: Any) -> str:
        return obj._compiler_dispatch(self, **kwargs)

    def __str__(self) -> str:
        """Return the string text of the generated SQL or DDL."""

        if self.state is CompilerState.STRING_APPLIED:
            return self.string
        else:
            # prior to string application there is no SQL text to show
            return ""

    def construct_params(
        self,
        params: Optional[_CoreSingleExecuteParams] = None,
        extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
        escape_names: bool = True,
    ) -> Optional[_MutableCoreSingleExecuteParams]:
        """Return the bind params for this compiled object.

        :param params: a dict of string/object pairs whose values will
         override bind values compiled in to the
         statement.

        This base implementation raises ``NotImplementedError``;
        subclasses must implement it.
        """

        raise NotImplementedError()

    @property
    def params(self):
        """Return the bind params for this compiled object."""
        return self.construct_params()
964
965
class TypeCompiler(util.EnsureKWArg):
    """Produces DDL specification for TypeEngine objects."""

    ensure_kwarg = r"visit_\w+"

    def __init__(self, dialect: Dialect):
        self.dialect = dialect

    def process(self, type_: TypeEngine[Any], **kw: Any) -> str:
        """Render the DDL string for the given type.

        If the type carries a variant mapping naming this compiler's
        dialect, the dialect-specific variant is rendered in place of
        the type itself.

        """
        variants = type_._variant_mapping
        if variants and self.dialect.name in variants:
            type_ = variants[self.dialect.name]
        return type_._compiler_dispatch(self, **kw)

    def visit_unsupported_compilation(
        self, element: Any, err: Exception, **kw: Any
    ) -> NoReturn:
        """Raise :class:`.UnsupportedCompilationError` for an element this
        compiler cannot render, chaining the original error."""
        raise exc.UnsupportedCompilationError(self, element) from err
986
987
# this was a Visitable, but to allow accurate detection of
# column elements this is actually a column element
class _CompileLabel(
    roles.BinaryElementRole[Any], elements.CompilerColumnElement
):
    """lightweight label object which acts as an expression.Label."""

    __visit_name__ = "label"
    __slots__ = "element", "name", "_alt_names"

    def __init__(self, col, name, alt_names=()):
        self.element = col
        self.name = name
        # the labeled column itself always participates as an
        # alternate name
        self._alt_names = (col,) + alt_names

    @property
    def proxy_set(self):
        # delegate to the wrapped element
        return self.element.proxy_set

    @property
    def type(self):
        # delegate to the wrapped element
        return self.element.type

    def self_group(self, **kw):
        # grouping (parenthesization) is a no-op for this construct
        return self
1013
1014
class ilike_case_insensitive(
    roles.BinaryElementRole[Any], elements.CompilerColumnElement
):
    """produce a wrapping element for a case-insensitive portion of
    an ILIKE construct.

    The construct usually renders the ``lower()`` function, but on
    PostgreSQL will pass silently with the assumption that "ILIKE"
    is being used.

    .. versionadded:: 2.0

    """

    __visit_name__ = "ilike_case_insensitive_operand"
    __slots__ = "element", "comparator"

    def __init__(self, element):
        self.element = element
        # NOTE(review): the wrapped element's comparator is forwarded,
        # presumably so operator dispatch behaves as if against the
        # element itself -- confirm
        self.comparator = element.comparator

    @property
    def proxy_set(self):
        # delegate to the wrapped element
        return self.element.proxy_set

    @property
    def type(self):
        # delegate to the wrapped element
        return self.element.type

    def self_group(self, **kw):
        # grouping (parenthesization) is a no-op for this wrapper
        return self

    def _with_binary_element_type(self, type_):
        # re-wrap the element coerced to the given type, preserving the
        # case-insensitivity marker
        return ilike_case_insensitive(
            self.element._with_binary_element_type(type_)
        )
1051
1052
1053class SQLCompiler(Compiled):
1054 """Default implementation of :class:`.Compiled`.
1055
1056 Compiles :class:`_expression.ClauseElement` objects into SQL strings.
1057
1058 """
1059
1060 extract_map = EXTRACT_MAP
1061
1062 bindname_escape_characters: ClassVar[Mapping[str, str]] = (
1063 util.immutabledict(
1064 {
1065 "%": "P",
1066 "(": "A",
1067 ")": "Z",
1068 ":": "C",
1069 ".": "_",
1070 "[": "_",
1071 "]": "_",
1072 " ": "_",
1073 }
1074 )
1075 )
1076 """A mapping (e.g. dict or similar) containing a lookup of
1077 characters keyed to replacement characters which will be applied to all
1078 'bind names' used in SQL statements as a form of 'escaping'; the given
1079 characters are replaced entirely with the 'replacement' character when
1080 rendered in the SQL statement, and a similar translation is performed
1081 on the incoming names used in parameter dictionaries passed to methods
1082 like :meth:`_engine.Connection.execute`.
1083
1084 This allows bound parameter names used in :func:`_sql.bindparam` and
1085 other constructs to have any arbitrary characters present without any
1086 concern for characters that aren't allowed at all on the target database.
1087
1088 Third party dialects can establish their own dictionary here to replace the
1089 default mapping, which will ensure that the particular characters in the
1090 mapping will never appear in a bound parameter name.
1091
1092 The dictionary is evaluated at **class creation time**, so cannot be
1093 modified at runtime; it must be present on the class when the class
1094 is first declared.
1095
1096 Note that for dialects that have additional bound parameter rules such
1097 as additional restrictions on leading characters, the
1098 :meth:`_sql.SQLCompiler.bindparam_string` method may need to be augmented.
1099 See the cx_Oracle compiler for an example of this.
1100
1101 .. versionadded:: 2.0.0rc1
1102
1103 """
1104
1105 _bind_translate_re: ClassVar[Pattern[str]]
1106 _bind_translate_chars: ClassVar[Mapping[str, str]]
1107
1108 is_sql = True
1109
1110 compound_keywords = COMPOUND_KEYWORDS
1111
1112 isdelete: bool = False
1113 isinsert: bool = False
1114 isupdate: bool = False
1115 """class-level defaults which can be set at the instance
1116 level to define if this Compiled instance represents
1117 INSERT/UPDATE/DELETE
1118 """
1119
1120 postfetch: Optional[List[Column[Any]]]
1121 """list of columns that can be post-fetched after INSERT or UPDATE to
1122 receive server-updated values"""
1123
1124 insert_prefetch: Sequence[Column[Any]] = ()
1125 """list of columns for which default values should be evaluated before
1126 an INSERT takes place"""
1127
1128 update_prefetch: Sequence[Column[Any]] = ()
1129 """list of columns for which onupdate default values should be evaluated
1130 before an UPDATE takes place"""
1131
1132 implicit_returning: Optional[Sequence[ColumnElement[Any]]] = None
1133 """list of "implicit" returning columns for a toplevel INSERT or UPDATE
1134 statement, used to receive newly generated values of columns.
1135
1136 .. versionadded:: 2.0 ``implicit_returning`` replaces the previous
1137 ``returning`` collection, which was not a generalized RETURNING
1138 collection and instead was in fact specific to the "implicit returning"
1139 feature.
1140
1141 """
1142
1143 isplaintext: bool = False
1144
1145 binds: Dict[str, BindParameter[Any]]
1146 """a dictionary of bind parameter keys to BindParameter instances."""
1147
1148 bind_names: Dict[BindParameter[Any], str]
1149 """a dictionary of BindParameter instances to "compiled" names
1150 that are actually present in the generated SQL"""
1151
1152 stack: List[_CompilerStackEntry]
1153 """major statements such as SELECT, INSERT, UPDATE, DELETE are
1154 tracked in this stack using an entry format."""
1155
1156 returning_precedes_values: bool = False
1157 """set to True classwide to generate RETURNING
1158 clauses before the VALUES or WHERE clause (i.e. MSSQL)
1159 """
1160
1161 render_table_with_column_in_update_from: bool = False
1162 """set to True classwide to indicate the SET clause
1163 in a multi-table UPDATE statement should qualify
1164 columns with the table name (i.e. MySQL only)
1165 """
1166
1167 ansi_bind_rules: bool = False
1168 """SQL 92 doesn't allow bind parameters to be used
1169 in the columns clause of a SELECT, nor does it allow
1170 ambiguous expressions like "? = ?". A compiler
1171 subclass can set this flag to False if the target
1172 driver/DB enforces this
1173 """
1174
1175 bindtemplate: str
1176 """template to render bound parameters based on paramstyle."""
1177
1178 compilation_bindtemplate: str
1179 """template used by compiler to render parameters before positional
1180 paramstyle application"""
1181
1182 _numeric_binds_identifier_char: str
1183 """Character that's used to as the identifier of a numerical bind param.
1184 For example if this char is set to ``$``, numerical binds will be rendered
1185 in the form ``$1, $2, $3``.
1186 """
1187
1188 _result_columns: List[ResultColumnsEntry]
1189 """relates label names in the final SQL to a tuple of local
1190 column/label name, ColumnElement object (if any) and
1191 TypeEngine. CursorResult uses this for type processing and
1192 column targeting"""
1193
1194 _textual_ordered_columns: bool = False
1195 """tell the result object that the column names as rendered are important,
1196 but they are also "ordered" vs. what is in the compiled object here.
1197
1198 As of 1.4.42 this condition is only present when the statement is a
1199 TextualSelect, e.g. text("....").columns(...), where it is required
1200 that the columns are considered positionally and not by name.
1201
1202 """
1203
1204 _ad_hoc_textual: bool = False
1205 """tell the result that we encountered text() or '*' constructs in the
1206 middle of the result columns, but we also have compiled columns, so
1207 if the number of columns in cursor.description does not match how many
1208 expressions we have, that means we can't rely on positional at all and
1209 should match on name.
1210
1211 """
1212
1213 _ordered_columns: bool = True
1214 """
1215 if False, means we can't be sure the list of entries
1216 in _result_columns is actually the rendered order. Usually
1217 True unless using an unordered TextualSelect.
1218 """
1219
1220 _loose_column_name_matching: bool = False
1221 """tell the result object that the SQL statement is textual, wants to match
1222 up to Column objects, and may be using the ._tq_label in the SELECT rather
1223 than the base name.
1224
1225 """
1226
1227 _numeric_binds: bool = False
1228 """
1229 True if paramstyle is "numeric". This paramstyle is trickier than
1230 all the others.
1231
1232 """
1233
1234 _render_postcompile: bool = False
1235 """
1236 whether to render out POSTCOMPILE params during the compile phase.
1237
1238 This attribute is used only for end-user invocation of stmt.compile();
1239 it's never used for actual statement execution, where instead the
1240 dialect internals access and render the internal postcompile structure
1241 directly.
1242
1243 """
1244
1245 _post_compile_expanded_state: Optional[ExpandedState] = None
1246 """When render_postcompile is used, the ``ExpandedState`` used to create
1247 the "expanded" SQL is assigned here, and then used by the ``.params``
1248 accessor and ``.construct_params()`` methods for their return values.
1249
1250 .. versionadded:: 2.0.0rc1
1251
1252 """
1253
1254 _pre_expanded_string: Optional[str] = None
1255 """Stores the original string SQL before 'post_compile' is applied,
1256 for cases where 'post_compile' were used.
1257
1258 """
1259
1260 _pre_expanded_positiontup: Optional[List[str]] = None
1261
1262 _insertmanyvalues: Optional[_InsertManyValues] = None
1263
1264 _insert_crud_params: Optional[crud._CrudParamSequence] = None
1265
1266 literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset()
1267 """bindparameter objects that are rendered as literal values at statement
1268 execution time.
1269
1270 """
1271
1272 post_compile_params: FrozenSet[BindParameter[Any]] = frozenset()
1273 """bindparameter objects that are rendered as bound parameter placeholders
1274 at statement execution time.
1275
1276 """
1277
1278 escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT
1279 """Late escaping of bound parameter names that has to be converted
1280 to the original name when looking in the parameter dictionary.
1281
1282 """
1283
1284 has_out_parameters = False
1285 """if True, there are bindparam() objects that have the isoutparam
1286 flag set."""
1287
1288 postfetch_lastrowid = False
1289 """if True, and this in insert, use cursor.lastrowid to populate
1290 result.inserted_primary_key. """
1291
1292 _cache_key_bind_match: Optional[
1293 Tuple[
1294 Dict[
1295 BindParameter[Any],
1296 List[BindParameter[Any]],
1297 ],
1298 Dict[
1299 str,
1300 BindParameter[Any],
1301 ],
1302 ]
1303 ] = None
1304 """a mapping that will relate the BindParameter object we compile
1305 to those that are part of the extracted collection of parameters
1306 in the cache key, if we were given a cache key.
1307
1308 """
1309
1310 positiontup: Optional[List[str]] = None
1311 """for a compiled construct that uses a positional paramstyle, will be
1312 a sequence of strings, indicating the names of bound parameters in order.
1313
1314 This is used in order to render bound parameters in their correct order,
1315 and is combined with the :attr:`_sql.Compiled.params` dictionary to
1316 render parameters.
1317
1318 This sequence always contains the unescaped name of the parameters.
1319
1320 .. seealso::
1321
1322 :ref:`faq_sql_expression_string` - includes a usage example for
1323 debugging use cases.
1324
1325 """
    # bind names that appear inside a VALUES clause; consulted by
    # _process_numeric() so that non-VALUES binds are numbered first
    _values_bindparam: Optional[List[str]] = None

    # bind names encountered during compilation
    # NOTE(review): populated elsewhere in this module; exact usage not
    # visible here — confirm against the rest of the file
    _visited_bindparam: Optional[List[str]] = None

    # set True by __init__ when the statement is marked ``_inline`` or
    # for certain for_executemany INSERT/UPDATE cases
    inline: bool = False

    # rendered text for each CTE; None until _init_cte_state() is called
    ctes: Optional[MutableMapping[CTE, str]]

    # Detect same CTE references - Dict[(level, name), cte]
    # Level is required for supporting nesting
    ctes_by_level_name: Dict[Tuple[int, str], CTE]

    # To retrieve key/level in ctes_by_level_name -
    # Dict[cte_reference, (level, cte_name, cte_opts)]
    level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]]

    # RECURSIVE-CTE flag; initialized to False in _init_cte_state()
    ctes_recursive: bool

    # matches "__[POSTCOMPILE_<name>]" tokens, optionally carrying a
    # "~~<prefix>~~~~<suffix>~~" bind-expression payload in group 2
    _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]")
    # matches pyformat-style "%(<name>)s" placeholders
    _pyformat_pattern = re.compile(r"%\(([^)]+?)\)s")
    # union of both placeholder forms, used when rewriting the statement
    # for positional paramstyles
    _positional_pattern = re.compile(
        f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}"
    )
1349
1350 @classmethod
1351 def _init_compiler_cls(cls):
1352 cls._init_bind_translate()
1353
1354 @classmethod
1355 def _init_bind_translate(cls):
1356 reg = re.escape("".join(cls.bindname_escape_characters))
1357 cls._bind_translate_re = re.compile(f"[{reg}]")
1358 cls._bind_translate_chars = cls.bindname_escape_characters
1359
    def __init__(
        self,
        dialect: Dialect,
        statement: Optional[ClauseElement],
        cache_key: Optional[CacheKey] = None,
        column_keys: Optional[Sequence[str]] = None,
        for_executemany: bool = False,
        linting: Linting = NO_LINTING,
        _supporting_against: Optional[SQLCompiler] = None,
        **kwargs: Any,
    ):
        """Construct a new :class:`.SQLCompiler` object.

        :param dialect: :class:`.Dialect` to be used

        :param statement: :class:`_expression.ClauseElement` to be compiled

        :param column_keys: a list of column names to be compiled into an
         INSERT or UPDATE statement.

        :param for_executemany: whether INSERT / UPDATE statements should
         expect that they are to be invoked in an "executemany" style,
         which may impact how the statement will be expected to return the
         values of defaults and autoincrement / sequences and similar.
         Depending on the backend and driver in use, support for retrieving
         these values may be disabled which means SQL expressions may
         be rendered inline, RETURNING may not be rendered, etc.

        :param kwargs: additional keyword arguments to be consumed by the
         superclass.

        """
        self.column_keys = column_keys

        self.cache_key = cache_key

        if cache_key:
            # cksm: bind .key -> bindparam as extracted into the cache key;
            # ckbm: cache-key bindparam -> list of compile-time binds that
            # correspond to it (starts as identity)
            cksm = {b.key: b for b in cache_key[1]}
            ckbm = {b: [b] for b in cache_key[1]}
            self._cache_key_bind_match = (ckbm, cksm)

        # compile INSERT/UPDATE defaults/sequences to expect executemany
        # style execution, which may mean no pre-execute of defaults,
        # or no RETURNING
        self.for_executemany = for_executemany

        self.linting = linting

        # a dictionary of bind parameter keys to BindParameter
        # instances.
        self.binds = {}

        # a dictionary of BindParameter instances to "compiled" names
        # that are actually present in the generated SQL
        self.bind_names = util.column_dict()

        # stack which keeps track of nested SELECT statements
        self.stack = []

        self._result_columns = []

        # true if the paramstyle is positional
        self.positional = dialect.positional
        if self.positional:
            self._numeric_binds = nb = dialect.paramstyle.startswith("numeric")
            if nb:
                self._numeric_binds_identifier_char = (
                    "$" if dialect.paramstyle == "numeric_dollar" else ":"
                )

            # positional paramstyles render pyformat placeholders first;
            # _process_positional() / _process_numeric() rewrite them once
            # the string is complete
            self.compilation_bindtemplate = _pyformat_template
        else:
            self.compilation_bindtemplate = BIND_TEMPLATES[dialect.paramstyle]

        self.ctes = None

        self.label_length = (
            dialect.label_length or dialect.max_identifier_length
        )

        # a map which tracks "anonymous" identifiers that are created on
        # the fly here
        self.anon_map = prefix_anon_map()

        # a map which tracks "truncated" names based on
        # dialect.label_length or dialect.max_identifier_length
        self.truncated_names: Dict[Tuple[str, str], str] = {}
        self._truncated_counters: Dict[str, int] = {}

        # superclass __init__ performs the actual statement compilation,
        # which may set isinsert/isupdate/isdelete and self.state
        Compiled.__init__(self, dialect, statement, **kwargs)

        if self.isinsert or self.isupdate or self.isdelete:
            if TYPE_CHECKING:
                assert isinstance(statement, UpdateBase)

            if self.isinsert or self.isupdate:
                if TYPE_CHECKING:
                    assert isinstance(statement, ValuesBase)
                if statement._inline:
                    self.inline = True
                elif self.for_executemany and (
                    not self.isinsert
                    or (
                        self.dialect.insert_executemany_returning
                        and statement._return_defaults
                    )
                ):
                    self.inline = True

        self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]

        if _supporting_against:
            # copy most compiler state from the given compiler; the
            # excluded keys keep this instance's own dialect / paramstyle
            # configuration intact
            self.__dict__.update(
                {
                    k: v
                    for k, v in _supporting_against.__dict__.items()
                    if k
                    not in {
                        "state",
                        "dialect",
                        "preparer",
                        "positional",
                        "_numeric_binds",
                        "compilation_bindtemplate",
                        "bindtemplate",
                    }
                }
            )

        if self.state is CompilerState.STRING_APPLIED:
            # SQL string is fully rendered; run paramstyle-specific
            # placeholder rewriting passes
            if self.positional:
                if self._numeric_binds:
                    self._process_numeric()
                else:
                    self._process_positional()

            if self._render_postcompile:
                parameters = self.construct_params(
                    escape_names=False,
                    _no_postcompile=True,
                )

                self._process_parameters_for_postcompile(
                    parameters, _populate_self=True
                )
1505
1506 @property
1507 def insert_single_values_expr(self) -> Optional[str]:
1508 """When an INSERT is compiled with a single set of parameters inside
1509 a VALUES expression, the string is assigned here, where it can be
1510 used for insert batching schemes to rewrite the VALUES expression.
1511
1512 .. versionchanged:: 2.0 This collection is no longer used by
1513 SQLAlchemy's built-in dialects, in favor of the currently
1514 internal ``_insertmanyvalues`` collection that is used only by
1515 :class:`.SQLCompiler`.
1516
1517 """
1518 if self._insertmanyvalues is None:
1519 return None
1520 else:
1521 return self._insertmanyvalues.single_values_expr
1522
1523 @util.ro_memoized_property
1524 def effective_returning(self) -> Optional[Sequence[ColumnElement[Any]]]:
1525 """The effective "returning" columns for INSERT, UPDATE or DELETE.
1526
1527 This is either the so-called "implicit returning" columns which are
1528 calculated by the compiler on the fly, or those present based on what's
1529 present in ``self.statement._returning`` (expanded into individual
1530 columns using the ``._all_selected_columns`` attribute) i.e. those set
1531 explicitly using the :meth:`.UpdateBase.returning` method.
1532
1533 .. versionadded:: 2.0
1534
1535 """
1536 if self.implicit_returning:
1537 return self.implicit_returning
1538 elif self.statement is not None and is_dml(self.statement):
1539 return [
1540 c
1541 for c in self.statement._all_selected_columns
1542 if is_column_element(c)
1543 ]
1544
1545 else:
1546 return None
1547
1548 @property
1549 def returning(self):
1550 """backwards compatibility; returns the
1551 effective_returning collection.
1552
1553 """
1554 return self.effective_returning
1555
1556 @property
1557 def current_executable(self):
1558 """Return the current 'executable' that is being compiled.
1559
1560 This is currently the :class:`_sql.Select`, :class:`_sql.Insert`,
1561 :class:`_sql.Update`, :class:`_sql.Delete`,
1562 :class:`_sql.CompoundSelect` object that is being compiled.
1563 Specifically it's assigned to the ``self.stack`` list of elements.
1564
1565 When a statement like the above is being compiled, it normally
1566 is also assigned to the ``.statement`` attribute of the
1567 :class:`_sql.Compiler` object. However, all SQL constructs are
1568 ultimately nestable, and this attribute should never be consulted
1569 by a ``visit_`` method, as it is not guaranteed to be assigned
1570 nor guaranteed to correspond to the current statement being compiled.
1571
1572 """
1573 try:
1574 return self.stack[-1]["selectable"]
1575 except IndexError as ie:
1576 raise IndexError("Compiler does not have a stack entry") from ie
1577
1578 @property
1579 def prefetch(self):
1580 return list(self.insert_prefetch) + list(self.update_prefetch)
1581
    @util.memoized_property
    def _global_attributes(self) -> Dict[Any, Any]:
        # lazily-created dictionary for arbitrary per-compiler state;
        # memoized, so all callers share the same dict instance
        return {}
1585
1586 @util.memoized_instancemethod
1587 def _init_cte_state(self) -> MutableMapping[CTE, str]:
1588 """Initialize collections related to CTEs only if
1589 a CTE is located, to save on the overhead of
1590 these collections otherwise.
1591
1592 """
1593 # collect CTEs to tack on top of a SELECT
1594 # To store the query to print - Dict[cte, text_query]
1595 ctes: MutableMapping[CTE, str] = util.OrderedDict()
1596 self.ctes = ctes
1597
1598 # Detect same CTE references - Dict[(level, name), cte]
1599 # Level is required for supporting nesting
1600 self.ctes_by_level_name = {}
1601
1602 # To retrieve key/level in ctes_by_level_name -
1603 # Dict[cte_reference, (level, cte_name, cte_opts)]
1604 self.level_name_by_cte = {}
1605
1606 self.ctes_recursive = False
1607
1608 return ctes
1609
1610 @contextlib.contextmanager
1611 def _nested_result(self):
1612 """special API to support the use case of 'nested result sets'"""
1613 result_columns, ordered_columns = (
1614 self._result_columns,
1615 self._ordered_columns,
1616 )
1617 self._result_columns, self._ordered_columns = [], False
1618
1619 try:
1620 if self.stack:
1621 entry = self.stack[-1]
1622 entry["need_result_map_for_nested"] = True
1623 else:
1624 entry = None
1625 yield self._result_columns, self._ordered_columns
1626 finally:
1627 if entry:
1628 entry.pop("need_result_map_for_nested")
1629 self._result_columns, self._ordered_columns = (
1630 result_columns,
1631 ordered_columns,
1632 )
1633
    def _process_positional(self):
        """Rewrite the rendered string for 'format' / 'qmark' paramstyles.

        Replaces the intermediate pyformat-style placeholders with the
        dialect's positional placeholder while recording the parameter
        order in ``self.positiontup``.  POSTCOMPILE tokens are left in
        place (they expand at execution time) but are still recorded in
        the positional order.
        """
        assert not self.positiontup
        assert self.state is CompilerState.STRING_APPLIED
        assert not self._numeric_binds

        if self.dialect.paramstyle == "format":
            placeholder = "%s"
        else:
            assert self.dialect.paramstyle == "qmark"
            placeholder = "?"

        positions = []

        def find_position(m: re.Match[str]) -> str:
            # group(1) is a pyformat bind name; group(2) is the name
            # inside a POSTCOMPILE token (see _positional_pattern)
            normal_bind = m.group(1)
            if normal_bind:
                positions.append(normal_bind)
                return placeholder
            else:
                # this is a post-compile bind; record it but leave the
                # token in the string for execution-time expansion
                positions.append(m.group(2))
                return m.group(0)

        self.string = re.sub(
            self._positional_pattern, find_position, self.string
        )

        if self.escaped_bind_names:
            # positiontup always carries the *unescaped* names
            reverse_escape = {v: k for k, v in self.escaped_bind_names.items()}
            assert len(self.escaped_bind_names) == len(reverse_escape)
            self.positiontup = [
                reverse_escape.get(name, name) for name in positions
            ]
        else:
            self.positiontup = positions

        if self._insertmanyvalues:
            # rebind the list that the find_position closure appends to,
            # so positions collected for the insertmanyvalues expressions
            # are tracked separately from the main statement's
            positions = []

            single_values_expr = re.sub(
                self._positional_pattern,
                find_position,
                self._insertmanyvalues.single_values_expr,
            )
            insert_crud_params = [
                (
                    v[0],
                    v[1],
                    re.sub(self._positional_pattern, find_position, v[2]),
                    v[3],
                )
                for v in self._insertmanyvalues.insert_crud_params
            ]

            self._insertmanyvalues = self._insertmanyvalues._replace(
                single_values_expr=single_values_expr,
                insert_crud_params=insert_crud_params,
            )
1692
    def _process_numeric(self):
        """Rewrite the rendered string for 'numeric' / 'numeric_dollar'
        paramstyles.

        Assigns each bound parameter a sequential number and replaces the
        intermediate pyformat placeholders with ``:1`` / ``$1`` style
        tokens.  Post-compile and literal-execute binds are recorded in
        ``positiontup`` but not numbered here.
        """
        assert self._numeric_binds
        assert self.state is CompilerState.STRING_APPLIED

        num = 1
        param_pos: Dict[str, str] = {}
        order: Iterable[str]
        if self._insertmanyvalues and self._values_bindparam is not None:
            # bindparams that are not in values are always placed first.
            # this avoids the need of changing them when using executemany
            # values () ()
            order = itertools.chain(
                (
                    name
                    for name in self.bind_names.values()
                    if name not in self._values_bindparam
                ),
                self.bind_names.values(),
            )
        else:
            order = self.bind_names.values()

        for bind_name in order:
            if bind_name in param_pos:
                continue
            bind = self.binds[bind_name]
            if (
                bind in self.post_compile_params
                or bind in self.literal_execute_params
            ):
                # set to None to just mark the in positiontup, it will not
                # be replaced below.
                param_pos[bind_name] = None  # type: ignore
            else:
                ph = f"{self._numeric_binds_identifier_char}{num}"
                num += 1
                param_pos[bind_name] = ph

        # next available number, e.g. for params added at execution time
        self.next_numeric_pos = num

        self.positiontup = list(param_pos)
        if self.escaped_bind_names:
            # the string contains escaped names; rekey the replacement
            # map accordingly (positiontup keeps the unescaped names)
            len_before = len(param_pos)
            param_pos = {
                self.escaped_bind_names.get(name, name): pos
                for name, pos in param_pos.items()
            }
            assert len(param_pos) == len_before

        # Can't use format here since % chars are not escaped.
        self.string = self._pyformat_pattern.sub(
            lambda m: param_pos[m.group(1)], self.string
        )

        if self._insertmanyvalues:
            single_values_expr = (
                # format is ok here since single_values_expr includes only
                # place-holders
                self._insertmanyvalues.single_values_expr
                % param_pos
            )
            insert_crud_params = [
                (v[0], v[1], "%s", v[3])
                for v in self._insertmanyvalues.insert_crud_params
            ]

            self._insertmanyvalues = self._insertmanyvalues._replace(
                # This has the numbers (:1, :2)
                single_values_expr=single_values_expr,
                # The single binds are instead %s so they can be formatted
                insert_crud_params=insert_crud_params,
            )
1765
1766 @util.memoized_property
1767 def _bind_processors(
1768 self,
1769 ) -> MutableMapping[
1770 str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]]
1771 ]:
1772 # mypy is not able to see the two value types as the above Union,
1773 # it just sees "object". don't know how to resolve
1774 return {
1775 key: value # type: ignore
1776 for key, value in (
1777 (
1778 self.bind_names[bindparam],
1779 (
1780 bindparam.type._cached_bind_processor(self.dialect)
1781 if not bindparam.type._is_tuple_type
1782 else tuple(
1783 elem_type._cached_bind_processor(self.dialect)
1784 for elem_type in cast(
1785 TupleType, bindparam.type
1786 ).types
1787 )
1788 ),
1789 )
1790 for bindparam in self.bind_names
1791 )
1792 if value is not None
1793 }
1794
1795 def is_subquery(self):
1796 return len(self.stack) > 1
1797
1798 @property
1799 def sql_compiler(self) -> Self:
1800 return self
1801
1802 def construct_expanded_state(
1803 self,
1804 params: Optional[_CoreSingleExecuteParams] = None,
1805 escape_names: bool = True,
1806 ) -> ExpandedState:
1807 """Return a new :class:`.ExpandedState` for a given parameter set.
1808
1809 For queries that use "expanding" or other late-rendered parameters,
1810 this method will provide for both the finalized SQL string as well
1811 as the parameters that would be used for a particular parameter set.
1812
1813 .. versionadded:: 2.0.0rc1
1814
1815 """
1816 parameters = self.construct_params(
1817 params,
1818 escape_names=escape_names,
1819 _no_postcompile=True,
1820 )
1821 return self._process_parameters_for_postcompile(
1822 parameters,
1823 )
1824
    def construct_params(
        self,
        params: Optional[_CoreSingleExecuteParams] = None,
        extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
        escape_names: bool = True,
        _group_number: Optional[int] = None,
        _check: bool = True,
        _no_postcompile: bool = False,
    ) -> _MutableCoreSingleExecuteParams:
        """return a dictionary of bind parameter keys and values"""

        if self._render_postcompile and not _no_postcompile:
            # with render_postcompile the values were fixed into the SQL
            # string at compile time, so only the original parameters may
            # be returned
            assert self._post_compile_expanded_state is not None
            if not params:
                return dict(self._post_compile_expanded_state.parameters)
            else:
                raise exc.InvalidRequestError(
                    "can't construct new parameters when render_postcompile "
                    "is used; the statement is hard-linked to the original "
                    "parameters. Use construct_expanded_state to generate a "
                    "new statement and parameters."
                )

        has_escaped_names = escape_names and bool(self.escaped_bind_names)

        if extracted_parameters:
            # related the bound parameters collected in the original cache key
            # to those collected in the incoming cache key. They will not have
            # matching names but they will line up positionally in the same
            # way. The parameters present in self.bind_names may be clones of
            # these original cache key params in the case of DML but the .key
            # will be guaranteed to match.
            if self.cache_key is None:
                raise exc.CompileError(
                    "This compiled object has no original cache key; "
                    "can't pass extracted_parameters to construct_params"
                )
            else:
                orig_extracted = self.cache_key[1]

            ckbm_tuple = self._cache_key_bind_match
            assert ckbm_tuple is not None
            ckbm, _ = ckbm_tuple
            resolved_extracted = {
                bind: extracted
                for b, extracted in zip(orig_extracted, extracted_parameters)
                for bind in ckbm[b]
            }
        else:
            resolved_extracted = None

        # two parallel loops below: the first consults the given ``params``
        # for values, the second uses compiled defaults only
        if params:
            pd = {}
            for bindparam, name in self.bind_names.items():
                # keys of the returned dict use the *escaped* names, as
                # those are what appear in the rendered SQL
                escaped_name = (
                    self.escaped_bind_names.get(name, name)
                    if has_escaped_names
                    else name
                )

                if bindparam.key in params:
                    pd[escaped_name] = params[bindparam.key]
                elif name in params:
                    pd[escaped_name] = params[name]

                elif _check and bindparam.required:
                    if _group_number:
                        raise exc.InvalidRequestError(
                            "A value is required for bind parameter %r, "
                            "in parameter group %d"
                            % (bindparam.key, _group_number),
                            code="cd3x",
                        )
                    else:
                        raise exc.InvalidRequestError(
                            "A value is required for bind parameter %r"
                            % bindparam.key,
                            code="cd3x",
                        )
                else:
                    # fall back to the compiled default; prefer the value
                    # from the extracted cache-key parameter if present
                    if resolved_extracted:
                        value_param = resolved_extracted.get(
                            bindparam, bindparam
                        )
                    else:
                        value_param = bindparam

                    if bindparam.callable:
                        pd[escaped_name] = value_param.effective_value
                    else:
                        pd[escaped_name] = value_param.value
            return pd
        else:
            pd = {}
            for bindparam, name in self.bind_names.items():
                escaped_name = (
                    self.escaped_bind_names.get(name, name)
                    if has_escaped_names
                    else name
                )

                if _check and bindparam.required:
                    if _group_number:
                        raise exc.InvalidRequestError(
                            "A value is required for bind parameter %r, "
                            "in parameter group %d"
                            % (bindparam.key, _group_number),
                            code="cd3x",
                        )
                    else:
                        raise exc.InvalidRequestError(
                            "A value is required for bind parameter %r"
                            % bindparam.key,
                            code="cd3x",
                        )

                if resolved_extracted:
                    value_param = resolved_extracted.get(bindparam, bindparam)
                else:
                    value_param = bindparam

                if bindparam.callable:
                    pd[escaped_name] = value_param.effective_value
                else:
                    pd[escaped_name] = value_param.value

            return pd
1952
1953 @util.memoized_instancemethod
1954 def _get_set_input_sizes_lookup(self):
1955 dialect = self.dialect
1956
1957 include_types = dialect.include_set_input_sizes
1958 exclude_types = dialect.exclude_set_input_sizes
1959
1960 dbapi = dialect.dbapi
1961
1962 def lookup_type(typ):
1963 dbtype = typ._unwrapped_dialect_impl(dialect).get_dbapi_type(dbapi)
1964
1965 if (
1966 dbtype is not None
1967 and (exclude_types is None or dbtype not in exclude_types)
1968 and (include_types is None or dbtype in include_types)
1969 ):
1970 return dbtype
1971 else:
1972 return None
1973
1974 inputsizes = {}
1975
1976 literal_execute_params = self.literal_execute_params
1977
1978 for bindparam in self.bind_names:
1979 if bindparam in literal_execute_params:
1980 continue
1981
1982 if bindparam.type._is_tuple_type:
1983 inputsizes[bindparam] = [
1984 lookup_type(typ)
1985 for typ in cast(TupleType, bindparam.type).types
1986 ]
1987 else:
1988 inputsizes[bindparam] = lookup_type(bindparam.type)
1989
1990 return inputsizes
1991
1992 @property
1993 def params(self):
1994 """Return the bind param dictionary embedded into this
1995 compiled object, for those values that are present.
1996
1997 .. seealso::
1998
1999 :ref:`faq_sql_expression_string` - includes a usage example for
2000 debugging use cases.
2001
2002 """
2003 return self.construct_params(_check=False)
2004
2005 def _process_parameters_for_postcompile(
2006 self,
2007 parameters: _MutableCoreSingleExecuteParams,
2008 _populate_self: bool = False,
2009 ) -> ExpandedState:
2010 """handle special post compile parameters.
2011
2012 These include:
2013
2014 * "expanding" parameters -typically IN tuples that are rendered
2015 on a per-parameter basis for an otherwise fixed SQL statement string.
2016
2017 * literal_binds compiled with the literal_execute flag. Used for
2018 things like SQL Server "TOP N" where the driver does not accommodate
2019 N as a bound parameter.
2020
2021 """
2022
2023 expanded_parameters = {}
2024 new_positiontup: Optional[List[str]]
2025
2026 pre_expanded_string = self._pre_expanded_string
2027 if pre_expanded_string is None:
2028 pre_expanded_string = self.string
2029
2030 if self.positional:
2031 new_positiontup = []
2032
2033 pre_expanded_positiontup = self._pre_expanded_positiontup
2034 if pre_expanded_positiontup is None:
2035 pre_expanded_positiontup = self.positiontup
2036
2037 else:
2038 new_positiontup = pre_expanded_positiontup = None
2039
2040 processors = self._bind_processors
2041 single_processors = cast(
2042 "Mapping[str, _BindProcessorType[Any]]", processors
2043 )
2044 tuple_processors = cast(
2045 "Mapping[str, Sequence[_BindProcessorType[Any]]]", processors
2046 )
2047
2048 new_processors: Dict[str, _BindProcessorType[Any]] = {}
2049
2050 replacement_expressions: Dict[str, Any] = {}
2051 to_update_sets: Dict[str, Any] = {}
2052
2053 # notes:
2054 # *unescaped* parameter names in:
2055 # self.bind_names, self.binds, self._bind_processors, self.positiontup
2056 #
2057 # *escaped* parameter names in:
2058 # construct_params(), replacement_expressions
2059
2060 numeric_positiontup: Optional[List[str]] = None
2061
2062 if self.positional and pre_expanded_positiontup is not None:
2063 names: Iterable[str] = pre_expanded_positiontup
2064 if self._numeric_binds:
2065 numeric_positiontup = []
2066 else:
2067 names = self.bind_names.values()
2068
2069 ebn = self.escaped_bind_names
2070 for name in names:
2071 escaped_name = ebn.get(name, name) if ebn else name
2072 parameter = self.binds[name]
2073
2074 if parameter in self.literal_execute_params:
2075 if escaped_name not in replacement_expressions:
2076 replacement_expressions[escaped_name] = (
2077 self.render_literal_bindparam(
2078 parameter,
2079 render_literal_value=parameters.pop(escaped_name),
2080 )
2081 )
2082 continue
2083
2084 if parameter in self.post_compile_params:
2085 if escaped_name in replacement_expressions:
2086 to_update = to_update_sets[escaped_name]
2087 values = None
2088 else:
2089 # we are removing the parameter from parameters
2090 # because it is a list value, which is not expected by
2091 # TypeEngine objects that would otherwise be asked to
2092 # process it. the single name is being replaced with
2093 # individual numbered parameters for each value in the
2094 # param.
2095 #
2096 # note we are also inserting *escaped* parameter names
2097 # into the given dictionary. default dialect will
2098 # use these param names directly as they will not be
2099 # in the escaped_bind_names dictionary.
2100 values = parameters.pop(name)
2101
2102 leep_res = self._literal_execute_expanding_parameter(
2103 escaped_name, parameter, values
2104 )
2105 (to_update, replacement_expr) = leep_res
2106
2107 to_update_sets[escaped_name] = to_update
2108 replacement_expressions[escaped_name] = replacement_expr
2109
2110 if not parameter.literal_execute:
2111 parameters.update(to_update)
2112 if parameter.type._is_tuple_type:
2113 assert values is not None
2114 new_processors.update(
2115 (
2116 "%s_%s_%s" % (name, i, j),
2117 tuple_processors[name][j - 1],
2118 )
2119 for i, tuple_element in enumerate(values, 1)
2120 for j, _ in enumerate(tuple_element, 1)
2121 if name in tuple_processors
2122 and tuple_processors[name][j - 1] is not None
2123 )
2124 else:
2125 new_processors.update(
2126 (key, single_processors[name])
2127 for key, _ in to_update
2128 if name in single_processors
2129 )
2130 if numeric_positiontup is not None:
2131 numeric_positiontup.extend(
2132 name for name, _ in to_update
2133 )
2134 elif new_positiontup is not None:
2135 # to_update has escaped names, but that's ok since
2136 # these are new names, that aren't in the
2137 # escaped_bind_names dict.
2138 new_positiontup.extend(name for name, _ in to_update)
2139 expanded_parameters[name] = [
2140 expand_key for expand_key, _ in to_update
2141 ]
2142 elif new_positiontup is not None:
2143 new_positiontup.append(name)
2144
2145 def process_expanding(m):
2146 key = m.group(1)
2147 expr = replacement_expressions[key]
2148
2149 # if POSTCOMPILE included a bind_expression, render that
2150 # around each element
2151 if m.group(2):
2152 tok = m.group(2).split("~~")
2153 be_left, be_right = tok[1], tok[3]
2154 expr = ", ".join(
2155 "%s%s%s" % (be_left, exp, be_right)
2156 for exp in expr.split(", ")
2157 )
2158 return expr
2159
2160 statement = re.sub(
2161 self._post_compile_pattern, process_expanding, pre_expanded_string
2162 )
2163
2164 if numeric_positiontup is not None:
2165 assert new_positiontup is not None
2166 param_pos = {
2167 key: f"{self._numeric_binds_identifier_char}{num}"
2168 for num, key in enumerate(
2169 numeric_positiontup, self.next_numeric_pos
2170 )
2171 }
2172 # Can't use format here since % chars are not escaped.
2173 statement = self._pyformat_pattern.sub(
2174 lambda m: param_pos[m.group(1)], statement
2175 )
2176 new_positiontup.extend(numeric_positiontup)
2177
2178 expanded_state = ExpandedState(
2179 statement,
2180 parameters,
2181 new_processors,
2182 new_positiontup,
2183 expanded_parameters,
2184 )
2185
2186 if _populate_self:
2187 # this is for the "render_postcompile" flag, which is not
2188 # otherwise used internally and is for end-user debugging and
2189 # special use cases.
2190 self._pre_expanded_string = pre_expanded_string
2191 self._pre_expanded_positiontup = pre_expanded_positiontup
2192 self.string = expanded_state.statement
2193 self.positiontup = (
2194 list(expanded_state.positiontup or ())
2195 if self.positional
2196 else None
2197 )
2198 self._post_compile_expanded_state = expanded_state
2199
2200 return expanded_state
2201
2202 @util.preload_module("sqlalchemy.engine.cursor")
2203 def _create_result_map(self):
2204 """utility method used for unit tests only."""
2205 cursor = util.preloaded.engine_cursor
2206 return cursor.CursorResultMetaData._create_description_match_map(
2207 self._result_columns
2208 )
2209
    # callable mapping a column to its bind parameter name; assigned by
    # crud.py when compiling insert/update statements
    _get_bind_name_for_col: _BindNameForColProtocol
2212
2213 @util.memoized_property
2214 def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
2215 getter = self._get_bind_name_for_col
2216 return getter
2217
    @util.memoized_property
    @util.preload_module("sqlalchemy.engine.result")
    def _inserted_primary_key_from_lastrowid_getter(self):
        """Return a ``(lastrowid, parameters) -> row`` callable producing
        the inserted primary key for dialects that use ``cursor.lastrowid``.

        Each primary key value comes either from the lastrowid (for the
        autoincrement column) or from the INSERT's bound parameters.
        """
        result = util.preloaded.engine_result

        param_key_getter = self._within_exec_param_key_getter

        assert self.compile_state is not None
        statement = self.compile_state.statement

        if TYPE_CHECKING:
            assert isinstance(statement, Insert)

        table = statement.table

        # one (parameter-dict getter, column) pair per primary key column
        getters = [
            (operator.methodcaller("get", param_key_getter(col), None), col)
            for col in table.primary_key
        ]

        autoinc_getter = None
        autoinc_col = table._autoincrement_column
        if autoinc_col is not None:
            # apply type post processors to the lastrowid
            lastrowid_processor = autoinc_col.type._cached_result_processor(
                self.dialect, None
            )
            autoinc_key = param_key_getter(autoinc_col)

            # if a bind value is present for the autoincrement column
            # in the parameters, we need to do the logic dictated by
            # #7998; honor a non-None user-passed parameter over lastrowid.
            # previously in the 1.4 series we weren't fetching lastrowid
            # at all if the key were present in the parameters
            if autoinc_key in self.binds:

                def _autoinc_getter(lastrowid, parameters):
                    param_value = parameters.get(autoinc_key, lastrowid)
                    if param_value is not None:
                        # they supplied non-None parameter, use that.
                        # SQLite at least is observed to return the wrong
                        # cursor.lastrowid for INSERT..ON CONFLICT so it
                        # can't be used in all cases
                        return param_value
                    else:
                        # use lastrowid
                        return lastrowid

                # work around mypy https://github.com/python/mypy/issues/14027
                autoinc_getter = _autoinc_getter

        else:
            lastrowid_processor = None

        row_fn = result.result_tuple([col.key for col in table.primary_key])

        def get(lastrowid, parameters):
            """given cursor.lastrowid value and the parameters used for INSERT,
            return a "row" that represents the primary key, either by
            using the "lastrowid" or by extracting values from the parameters
            that were sent along with the INSERT.

            """
            if lastrowid_processor is not None:
                lastrowid = lastrowid_processor(lastrowid)

            if lastrowid is None:
                # no lastrowid available; pull every value from parameters
                return row_fn(getter(parameters) for getter, col in getters)
            else:
                return row_fn(
                    (
                        (
                            autoinc_getter(lastrowid, parameters)
                            if autoinc_getter is not None
                            else lastrowid
                        )
                        if col is autoinc_col
                        else getter(parameters)
                    )
                    for getter, col in getters
                )

        return get
2301
    @util.memoized_property
    @util.preload_module("sqlalchemy.engine.result")
    def _inserted_primary_key_from_returning_getter(self):
        """Return a ``(row, parameters) -> row`` callable producing the
        inserted primary key from an INSERT's RETURNING result row, falling
        back to the bound parameters for columns absent from RETURNING.
        """
        result = util.preloaded.engine_result

        assert self.compile_state is not None
        statement = self.compile_state.statement

        if TYPE_CHECKING:
            assert isinstance(statement, Insert)

        param_key_getter = self._within_exec_param_key_getter
        table = statement.table

        returning = self.implicit_returning
        assert returning is not None
        # position of each RETURNING column within the result row
        ret = {col: idx for idx, col in enumerate(returning)}

        getters = cast(
            "List[Tuple[Callable[[Any], Any], bool]]",
            [
                (
                    # (getter, use_row): read from the RETURNING row when
                    # the column appears there, else from the parameter dict
                    (operator.itemgetter(ret[col]), True)
                    if col in ret
                    else (
                        operator.methodcaller(
                            "get", param_key_getter(col), None
                        ),
                        False,
                    )
                )
                for col in table.primary_key
            ],
        )

        row_fn = result.result_tuple([col.key for col in table.primary_key])

        def get(row, parameters):
            return row_fn(
                getter(row) if use_row else getter(parameters)
                for getter, use_row in getters
            )

        return get
2346
2347 def default_from(self) -> str:
2348 """Called when a SELECT statement has no froms, and no FROM clause is
2349 to be appended.
2350
2351 Gives Oracle Database a chance to tack on a ``FROM DUAL`` to the string
2352 output.
2353
2354 """
2355 return ""
2356
    def visit_override_binds(self, override_binds, **kw):
        """SQL compile the nested element of an _OverrideBinds with
        bindparams swapped out.

        The _OverrideBinds is not normally expected to be compiled; it
        is meant to be used when an already cached statement is to be used,
        the compilation was already performed, and only the bound params should
        be swapped in at execution time.

        However, there are test cases that exercise this object, and
        additionally the ORM subquery loader is known to feed in expressions
        which include this construct into new queries (discovered in #11173),
        so it has to do the right thing at compile time as well.

        """

        # get SQL text first
        sqltext = override_binds.element._compiler_dispatch(self, **kw)

        # for a test compile that is not for caching, change binds after the
        # fact. note that we don't try to
        # swap the bindparam as we compile, because our element may be
        # elsewhere in the statement already (e.g. a subquery or perhaps a
        # CTE) and was already visited / compiled. See
        # test_relationship_criteria.py ->
        # test_selectinload_local_criteria_subquery
        for k in override_binds.translate:
            if k not in self.binds:
                continue
            bp = self.binds[k]

            # so this would work, just change the value of bp in place.
            # but we dont want to mutate things outside.
            # bp.value = override_binds.translate[bp.key]
            # continue

            # instead, need to replace bp with new_bp or otherwise accommodate
            # in all internal collections
            new_bp = bp._with_value(
                override_binds.translate[bp.key],
                maintain_key=True,
                required=False,
            )

            # rewire every internal collection that referenced the old
            # bindparam so it points at the replacement
            name = self.bind_names[bp]
            self.binds[k] = self.binds[name] = new_bp
            self.bind_names[new_bp] = name
            self.bind_names.pop(bp, None)

            if bp in self.post_compile_params:
                self.post_compile_params |= {new_bp}
            if bp in self.literal_execute_params:
                self.literal_execute_params |= {new_bp}

            # also register the new param against the cache key bind match,
            # so execution-time parameter substitution can locate it
            ckbm_tuple = self._cache_key_bind_match
            if ckbm_tuple:
                ckbm, cksm = ckbm_tuple
                for bp in bp._cloned_set:
                    if bp.key in cksm:
                        cb = cksm[bp.key]
                        ckbm[cb].append(new_bp)

        return sqltext
2420
2421 def visit_grouping(self, grouping, asfrom=False, **kwargs):
2422 return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
2423
2424 def visit_select_statement_grouping(self, grouping, **kwargs):
2425 return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
2426
    def visit_label_reference(
        self, element, within_columns_clause=False, **kwargs
    ):
        """Render a reference to a SELECT's label in ORDER BY / GROUP BY.

        When the dialect supports ordering by plain label names, the
        reference is resolved against the enclosing compile state's label
        dictionary so the bare label name can be emitted; otherwise the
        referenced expression is rendered in full.
        """
        if self.stack and self.dialect.supports_simple_order_by_label:
            try:
                compile_state = cast(
                    "Union[SelectState, CompoundSelectState]",
                    self.stack[-1]["compile_state"],
                )
            except KeyError as ke:
                raise exc.CompileError(
                    "Can't resolve label reference for ORDER BY / "
                    "GROUP BY / DISTINCT etc."
                ) from ke

            (
                with_cols,
                only_froms,
                only_cols,
            ) = compile_state._label_resolve_dict
            if within_columns_clause:
                resolve_dict = only_froms
            else:
                resolve_dict = only_cols

            # this can be None in the case that a _label_reference()
            # were subject to a replacement operation, in which case
            # the replacement of the Label element may have changed
            # to something else like a ColumnClause expression.
            order_by_elem = element.element._order_by_label_element

            if (
                order_by_elem is not None
                and order_by_elem.name in resolve_dict
                and order_by_elem.shares_lineage(
                    resolve_dict[order_by_elem.name]
                )
            ):
                kwargs["render_label_as_label"] = (
                    element.element._order_by_label_element
                )
        return self.process(
            element.element,
            within_columns_clause=within_columns_clause,
            **kwargs,
        )
2473
    def visit_textual_label_reference(
        self, element, within_columns_clause=False, **kwargs
    ):
        """Render a string label reference used in ORDER BY / GROUP BY.

        The text is matched against the enclosing SELECT's resolvable
        columns; failure to resolve raises a CompileError rather than
        passing raw text through.
        """
        if not self.stack:
            # compiling the element outside of the context of a SELECT
            return self.process(element._text_clause)

        try:
            compile_state = cast(
                "Union[SelectState, CompoundSelectState]",
                self.stack[-1]["compile_state"],
            )
        except KeyError as ke:
            coercions._no_text_coercion(
                element.element,
                extra=(
                    "Can't resolve label reference for ORDER BY / "
                    "GROUP BY / DISTINCT etc."
                ),
                exc_cls=exc.CompileError,
                err=ke,
            )

        with_cols, only_froms, only_cols = compile_state._label_resolve_dict
        try:
            if within_columns_clause:
                col = only_froms[element.element]
            else:
                col = with_cols[element.element]
        except KeyError as err:
            coercions._no_text_coercion(
                element.element,
                extra=(
                    "Can't resolve label reference for ORDER BY / "
                    "GROUP BY / DISTINCT etc."
                ),
                exc_cls=exc.CompileError,
                err=err,
            )
        else:
            kwargs["render_label_as_label"] = col
            return self.process(
                col, within_columns_clause=within_columns_clause, **kwargs
            )
2518
    def visit_label(
        self,
        label,
        add_to_result_map=None,
        within_label_clause=False,
        within_columns_clause=False,
        render_label_as_label=None,
        result_map_targets=(),
        **kw,
    ):
        """Render a Label as ``<expr> AS <name>``, as the bare label name,
        or as only the inner expression, depending on clause context."""
        # only render labels within the columns clause
        # or ORDER BY clause of a select. dialect-specific compilers
        # can modify this behavior.
        render_label_with_as = (
            within_columns_clause and not within_label_clause
        )
        render_label_only = render_label_as_label is label

        if render_label_only or render_label_with_as:
            if isinstance(label.name, elements._truncated_label):
                labelname = self._truncated_identifier("colident", label.name)
            else:
                labelname = label.name

        if render_label_with_as:
            if add_to_result_map is not None:
                add_to_result_map(
                    labelname,
                    label.name,
                    (label, labelname) + label._alt_names + result_map_targets,
                    label.type,
                )
            return (
                label.element._compiler_dispatch(
                    self,
                    within_columns_clause=True,
                    within_label_clause=True,
                    **kw,
                )
                + OPERATORS[operators.as_]
                + self.preparer.format_label(label, labelname)
            )
        elif render_label_only:
            return self.preparer.format_label(label, labelname)
        else:
            return label.element._compiler_dispatch(
                self, within_columns_clause=False, **kw
            )
2567
2568 def _fallback_column_name(self, column):
2569 raise exc.CompileError(
2570 "Cannot compile Column object until its 'name' is assigned."
2571 )
2572
2573 def visit_lambda_element(self, element, **kw):
2574 sql_element = element._resolved
2575 return self.process(sql_element, **kw)
2576
    def visit_column(
        self,
        column: ColumnClause[Any],
        add_to_result_map: Optional[_ResultMapAppender] = None,
        include_table: bool = True,
        result_map_targets: Tuple[Any, ...] = (),
        ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
        **kwargs: Any,
    ) -> str:
        """Render a ColumnClause / Column, optionally table-qualified.

        ``add_to_result_map`` is invoked when the column participates in
        the result set of the enclosing SELECT;
        ``ambiguous_table_name_map`` substitutes disambiguated table names
        where a name collision exists.
        """
        name = orig_name = column.name
        if name is None:
            # raises CompileError in the base implementation
            name = self._fallback_column_name(column)

        is_literal = column.is_literal
        if not is_literal and isinstance(name, elements._truncated_label):
            name = self._truncated_identifier("colident", name)

        if add_to_result_map is not None:
            targets = (column, name, column.key) + result_map_targets
            if column._tq_label:
                targets += (column._tq_label,)

            add_to_result_map(name, orig_name, targets, column.type)

        if is_literal:
            # note we are not currently accommodating for
            # literal_column(quoted_name('ident', True)) here
            name = self.escape_literal_column(name)
        else:
            name = self.preparer.quote(name)
        table = column.table
        if table is None or not include_table or not table.named_with_column:
            # unqualified column name
            return name
        else:
            effective_schema = self.preparer.schema_for_object(table)

            if effective_schema:
                schema_prefix = (
                    self.preparer.quote_schema(effective_schema) + "."
                )
            else:
                schema_prefix = ""

            if TYPE_CHECKING:
                assert isinstance(table, NamedFromClause)
            tablename = table.name

            if (
                not effective_schema
                and ambiguous_table_name_map
                and tablename in ambiguous_table_name_map
            ):
                tablename = ambiguous_table_name_map[tablename]

            if isinstance(tablename, elements._truncated_label):
                tablename = self._truncated_identifier("alias", tablename)

            return schema_prefix + self.preparer.quote(tablename) + "." + name
2635
2636 def visit_collation(self, element, **kw):
2637 return self.preparer.format_collation(element.collation)
2638
2639 def visit_fromclause(self, fromclause, **kwargs):
2640 return fromclause.name
2641
2642 def visit_index(self, index, **kwargs):
2643 return index.name
2644
2645 def visit_typeclause(self, typeclause, **kw):
2646 kw["type_expression"] = typeclause
2647 kw["identifier_preparer"] = self.preparer
2648 return self.dialect.type_compiler_instance.process(
2649 typeclause.type, **kw
2650 )
2651
2652 def post_process_text(self, text):
2653 if self.preparer._double_percents:
2654 text = text.replace("%", "%%")
2655 return text
2656
2657 def escape_literal_column(self, text):
2658 if self.preparer._double_percents:
2659 text = text.replace("%", "%%")
2660 return text
2661
    def visit_textclause(self, textclause, add_to_result_map=None, **kw):
        """Render a text() construct, substituting its bound parameters.

        ``:name`` tokens are replaced with bindparam syntax; escaped
        parameters (backslash-colon) are un-escaped afterwards.
        """
        def do_bindparam(m):
            # substitute either the declared bindparam or a new anonymous one
            name = m.group(1)
            if name in textclause._bindparams:
                return self.process(textclause._bindparams[name], **kw)
            else:
                return self.bindparam_string(name, **kw)

        if not self.stack:
            self.isplaintext = True

        if add_to_result_map:
            # text() object is present in the columns clause of a
            # select(). Add a no-name entry to the result map so that
            # row[text()] produces a result
            add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE)

        # un-escape any \:params
        return BIND_PARAMS_ESC.sub(
            lambda m: m.group(1),
            BIND_PARAMS.sub(
                do_bindparam, self.post_process_text(textclause.text)
            ),
        )
2686
    def visit_textual_select(
        self, taf, compound_index=None, asfrom=False, **kw
    ):
        """Render a TextualSelect (text() paired with column information).

        Sets up a compiler stack entry and, when this is the outermost
        element or one that requires a result map, registers the column
        args in the result map before rendering the inner text element.
        """
        toplevel = not self.stack
        entry = self._default_stack_entry if toplevel else self.stack[-1]

        new_entry: _CompilerStackEntry = {
            "correlate_froms": set(),
            "asfrom_froms": set(),
            "selectable": taf,
        }
        self.stack.append(new_entry)

        if taf._independent_ctes:
            self._dispatch_independent_ctes(taf, kw)

        populate_result_map = (
            toplevel
            or (
                compound_index == 0
                and entry.get("need_result_map_for_compound", False)
            )
            or entry.get("need_result_map_for_nested", False)
        )

        if populate_result_map:
            self._ordered_columns = self._textual_ordered_columns = (
                taf.positional
            )

            # enable looser result column matching when the SQL text links to
            # Column objects by name only
            self._loose_column_name_matching = not taf.positional and bool(
                taf.column_args
            )

            for c in taf.column_args:
                self.process(
                    c,
                    within_columns_clause=True,
                    add_to_result_map=self._add_to_result_map,
                )

        text = self.process(taf.element, **kw)
        if self.ctes:
            # prepend accumulated WITH clauses at the appropriate nesting
            nesting_level = len(self.stack) if not toplevel else None
            text = self._render_cte_clause(nesting_level=nesting_level) + text

        self.stack.pop(-1)

        return text
2738
2739 def visit_null(self, expr: Null, **kw: Any) -> str:
2740 return "NULL"
2741
2742 def visit_true(self, expr: True_, **kw: Any) -> str:
2743 if self.dialect.supports_native_boolean:
2744 return "true"
2745 else:
2746 return "1"
2747
2748 def visit_false(self, expr: False_, **kw: Any) -> str:
2749 if self.dialect.supports_native_boolean:
2750 return "false"
2751 else:
2752 return "0"
2753
2754 def _generate_delimited_list(self, elements, separator, **kw):
2755 return separator.join(
2756 s
2757 for s in (c._compiler_dispatch(self, **kw) for c in elements)
2758 if s
2759 )
2760
2761 def _generate_delimited_and_list(self, clauses, **kw):
2762 lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean(
2763 operators.and_,
2764 elements.True_._singleton,
2765 elements.False_._singleton,
2766 clauses,
2767 )
2768 if lcc == 1:
2769 return clauses[0]._compiler_dispatch(self, **kw)
2770 else:
2771 separator = OPERATORS[operators.and_]
2772 return separator.join(
2773 s
2774 for s in (c._compiler_dispatch(self, **kw) for c in clauses)
2775 if s
2776 )
2777
2778 def visit_tuple(self, clauselist, **kw):
2779 return "(%s)" % self.visit_clauselist(clauselist, **kw)
2780
2781 def visit_element_list(self, element, **kw):
2782 return self._generate_delimited_list(element.clauses, " ", **kw)
2783
2784 def visit_clauselist(self, clauselist, **kw):
2785 sep = clauselist.operator
2786 if sep is None:
2787 sep = " "
2788 else:
2789 sep = OPERATORS[clauselist.operator]
2790
2791 return self._generate_delimited_list(clauselist.clauses, sep, **kw)
2792
2793 def visit_expression_clauselist(self, clauselist, **kw):
2794 operator_ = clauselist.operator
2795
2796 disp = self._get_operator_dispatch(
2797 operator_, "expression_clauselist", None
2798 )
2799 if disp:
2800 return disp(clauselist, operator_, **kw)
2801
2802 try:
2803 opstring = OPERATORS[operator_]
2804 except KeyError as err:
2805 raise exc.UnsupportedCompilationError(self, operator_) from err
2806 else:
2807 kw["_in_operator_expression"] = True
2808 return self._generate_delimited_list(
2809 clauselist.clauses, opstring, **kw
2810 )
2811
2812 def visit_case(self, clause, **kwargs):
2813 x = "CASE "
2814 if clause.value is not None:
2815 x += clause.value._compiler_dispatch(self, **kwargs) + " "
2816 for cond, result in clause.whens:
2817 x += (
2818 "WHEN "
2819 + cond._compiler_dispatch(self, **kwargs)
2820 + " THEN "
2821 + result._compiler_dispatch(self, **kwargs)
2822 + " "
2823 )
2824 if clause.else_ is not None:
2825 x += (
2826 "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " "
2827 )
2828 x += "END"
2829 return x
2830
2831 def visit_type_coerce(self, type_coerce, **kw):
2832 return type_coerce.typed_expression._compiler_dispatch(self, **kw)
2833
2834 def visit_cast(self, cast, **kwargs):
2835 type_clause = cast.typeclause._compiler_dispatch(self, **kwargs)
2836 match = re.match("(.*)( COLLATE .*)", type_clause)
2837 return "CAST(%s AS %s)%s" % (
2838 cast.clause._compiler_dispatch(self, **kwargs),
2839 match.group(1) if match else type_clause,
2840 match.group(2) if match else "",
2841 )
2842
2843 def visit_frame_clause(self, frameclause, **kw):
2844
2845 if frameclause.lower_type is elements._FrameClauseType.RANGE_UNBOUNDED:
2846 left = "UNBOUNDED PRECEDING"
2847 elif frameclause.lower_type is elements._FrameClauseType.RANGE_CURRENT:
2848 left = "CURRENT ROW"
2849 else:
2850 val = self.process(frameclause.lower_integer_bind, **kw)
2851 if (
2852 frameclause.lower_type
2853 is elements._FrameClauseType.RANGE_PRECEDING
2854 ):
2855 left = f"{val} PRECEDING"
2856 else:
2857 left = f"{val} FOLLOWING"
2858
2859 if frameclause.upper_type is elements._FrameClauseType.RANGE_UNBOUNDED:
2860 right = "UNBOUNDED FOLLOWING"
2861 elif frameclause.upper_type is elements._FrameClauseType.RANGE_CURRENT:
2862 right = "CURRENT ROW"
2863 else:
2864 val = self.process(frameclause.upper_integer_bind, **kw)
2865 if (
2866 frameclause.upper_type
2867 is elements._FrameClauseType.RANGE_PRECEDING
2868 ):
2869 right = f"{val} PRECEDING"
2870 else:
2871 right = f"{val} FOLLOWING"
2872
2873 return f"{left} AND {right}"
2874
2875 def visit_over(self, over, **kwargs):
2876 text = over.element._compiler_dispatch(self, **kwargs)
2877 if over.range_ is not None:
2878 range_ = f"RANGE BETWEEN {self.process(over.range_, **kwargs)}"
2879 elif over.rows is not None:
2880 range_ = f"ROWS BETWEEN {self.process(over.rows, **kwargs)}"
2881 elif over.groups is not None:
2882 range_ = f"GROUPS BETWEEN {self.process(over.groups, **kwargs)}"
2883 else:
2884 range_ = None
2885
2886 return "%s OVER (%s)" % (
2887 text,
2888 " ".join(
2889 [
2890 "%s BY %s"
2891 % (word, clause._compiler_dispatch(self, **kwargs))
2892 for word, clause in (
2893 ("PARTITION", over.partition_by),
2894 ("ORDER", over.order_by),
2895 )
2896 if clause is not None and len(clause)
2897 ]
2898 + ([range_] if range_ else [])
2899 ),
2900 )
2901
2902 def visit_withingroup(self, withingroup, **kwargs):
2903 return "%s WITHIN GROUP (ORDER BY %s)" % (
2904 withingroup.element._compiler_dispatch(self, **kwargs),
2905 withingroup.order_by._compiler_dispatch(self, **kwargs),
2906 )
2907
2908 def visit_funcfilter(self, funcfilter, **kwargs):
2909 return "%s FILTER (WHERE %s)" % (
2910 funcfilter.func._compiler_dispatch(self, **kwargs),
2911 funcfilter.criterion._compiler_dispatch(self, **kwargs),
2912 )
2913
2914 def visit_extract(self, extract, **kwargs):
2915 field = self.extract_map.get(extract.field, extract.field)
2916 return "EXTRACT(%s FROM %s)" % (
2917 field,
2918 extract.expr._compiler_dispatch(self, **kwargs),
2919 )
2920
2921 def visit_scalar_function_column(self, element, **kw):
2922 compiled_fn = self.visit_function(element.fn, **kw)
2923 compiled_col = self.visit_column(element, **kw)
2924 return "(%s).%s" % (compiled_fn, compiled_col)
2925
2926 def visit_function(
2927 self,
2928 func: Function[Any],
2929 add_to_result_map: Optional[_ResultMapAppender] = None,
2930 **kwargs: Any,
2931 ) -> str:
2932 if add_to_result_map is not None:
2933 add_to_result_map(func.name, func.name, (func.name,), func.type)
2934
2935 disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
2936
2937 text: str
2938
2939 if disp:
2940 text = disp(func, **kwargs)
2941 else:
2942 name = FUNCTIONS.get(func._deannotate().__class__, None)
2943 if name:
2944 if func._has_args:
2945 name += "%(expr)s"
2946 else:
2947 name = func.name
2948 name = (
2949 self.preparer.quote(name)
2950 if self.preparer._requires_quotes_illegal_chars(name)
2951 or isinstance(name, elements.quoted_name)
2952 else name
2953 )
2954 name = name + "%(expr)s"
2955 text = ".".join(
2956 [
2957 (
2958 self.preparer.quote(tok)
2959 if self.preparer._requires_quotes_illegal_chars(tok)
2960 or isinstance(name, elements.quoted_name)
2961 else tok
2962 )
2963 for tok in func.packagenames
2964 ]
2965 + [name]
2966 ) % {"expr": self.function_argspec(func, **kwargs)}
2967
2968 if func._with_ordinality:
2969 text += " WITH ORDINALITY"
2970 return text
2971
2972 def visit_next_value_func(self, next_value, **kw):
2973 return self.visit_sequence(next_value.sequence)
2974
2975 def visit_sequence(self, sequence, **kw):
2976 raise NotImplementedError(
2977 "Dialect '%s' does not support sequence increments."
2978 % self.dialect.name
2979 )
2980
2981 def function_argspec(self, func: Function[Any], **kwargs: Any) -> str:
2982 return func.clause_expr._compiler_dispatch(self, **kwargs)
2983
    def visit_compound_select(
        self, cs, asfrom=False, compound_index=None, **kwargs
    ):
        """Render a CompoundSelect (UNION / INTERSECT / EXCEPT etc.).

        Each component SELECT is dispatched with its ``compound_index``;
        the first component populates the result map when required by the
        enclosing context.
        """
        toplevel = not self.stack

        compile_state = cs._compile_state_factory(cs, self, **kwargs)

        if toplevel and not self.compile_state:
            self.compile_state = compile_state

        compound_stmt = compile_state.statement

        entry = self._default_stack_entry if toplevel else self.stack[-1]
        need_result_map = toplevel or (
            not compound_index
            and entry.get("need_result_map_for_compound", False)
        )

        # indicates there is already a CompoundSelect in play
        if compound_index == 0:
            entry["select_0"] = cs

        self.stack.append(
            {
                "correlate_froms": entry["correlate_froms"],
                "asfrom_froms": entry["asfrom_froms"],
                "selectable": cs,
                "compile_state": compile_state,
                "need_result_map_for_compound": need_result_map,
            }
        )

        if compound_stmt._independent_ctes:
            self._dispatch_independent_ctes(compound_stmt, kwargs)

        # e.g. " UNION ", " INTERSECT ALL ", per the compound keyword
        keyword = self.compound_keywords[cs.keyword]

        text = (" " + keyword + " ").join(
            (
                c._compiler_dispatch(
                    self, asfrom=asfrom, compound_index=i, **kwargs
                )
                for i, c in enumerate(cs.selects)
            )
        )

        kwargs["include_table"] = False
        text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs))
        text += self.order_by_clause(cs, **kwargs)
        if cs._has_row_limiting_clause:
            text += self._row_limit_clause(cs, **kwargs)

        if self.ctes:
            # prepend accumulated WITH clauses at the appropriate nesting
            nesting_level = len(self.stack) if not toplevel else None
            text = (
                self._render_cte_clause(
                    nesting_level=nesting_level,
                    include_following_stack=True,
                )
                + text
            )

        self.stack.pop(-1)
        return text
3048
3049 def _row_limit_clause(self, cs, **kwargs):
3050 if cs._fetch_clause is not None:
3051 return self.fetch_clause(cs, **kwargs)
3052 else:
3053 return self.limit_clause(cs, **kwargs)
3054
3055 def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
3056 attrname = "visit_%s_%s%s" % (
3057 operator_.__name__,
3058 qualifier1,
3059 "_" + qualifier2 if qualifier2 else "",
3060 )
3061 return getattr(self, attrname, None)
3062
3063 def visit_unary(
3064 self, unary, add_to_result_map=None, result_map_targets=(), **kw
3065 ):
3066 if add_to_result_map is not None:
3067 result_map_targets += (unary,)
3068 kw["add_to_result_map"] = add_to_result_map
3069 kw["result_map_targets"] = result_map_targets
3070
3071 if unary.operator:
3072 if unary.modifier:
3073 raise exc.CompileError(
3074 "Unary expression does not support operator "
3075 "and modifier simultaneously"
3076 )
3077 disp = self._get_operator_dispatch(
3078 unary.operator, "unary", "operator"
3079 )
3080 if disp:
3081 return disp(unary, unary.operator, **kw)
3082 else:
3083 return self._generate_generic_unary_operator(
3084 unary, OPERATORS[unary.operator], **kw
3085 )
3086 elif unary.modifier:
3087 disp = self._get_operator_dispatch(
3088 unary.modifier, "unary", "modifier"
3089 )
3090 if disp:
3091 return disp(unary, unary.modifier, **kw)
3092 else:
3093 return self._generate_generic_unary_modifier(
3094 unary, OPERATORS[unary.modifier], **kw
3095 )
3096 else:
3097 raise exc.CompileError(
3098 "Unary expression has no operator or modifier"
3099 )
3100
3101 def visit_truediv_binary(self, binary, operator, **kw):
3102 if self.dialect.div_is_floordiv:
3103 return (
3104 self.process(binary.left, **kw)
3105 + " / "
3106 # TODO: would need a fast cast again here,
3107 # unless we want to use an implicit cast like "+ 0.0"
3108 + self.process(
3109 elements.Cast(
3110 binary.right,
3111 (
3112 binary.right.type
3113 if binary.right.type._type_affinity
3114 in (sqltypes.Numeric, sqltypes.Float)
3115 else sqltypes.Numeric()
3116 ),
3117 ),
3118 **kw,
3119 )
3120 )
3121 else:
3122 return (
3123 self.process(binary.left, **kw)
3124 + " / "
3125 + self.process(binary.right, **kw)
3126 )
3127
3128 def visit_floordiv_binary(self, binary, operator, **kw):
3129 if (
3130 self.dialect.div_is_floordiv
3131 and binary.right.type._type_affinity is sqltypes.Integer
3132 ):
3133 return (
3134 self.process(binary.left, **kw)
3135 + " / "
3136 + self.process(binary.right, **kw)
3137 )
3138 else:
3139 return "FLOOR(%s)" % (
3140 self.process(binary.left, **kw)
3141 + " / "
3142 + self.process(binary.right, **kw)
3143 )
3144
3145 def visit_is_true_unary_operator(self, element, operator, **kw):
3146 if (
3147 element._is_implicitly_boolean
3148 or self.dialect.supports_native_boolean
3149 ):
3150 return self.process(element.element, **kw)
3151 else:
3152 return "%s = 1" % self.process(element.element, **kw)
3153
3154 def visit_is_false_unary_operator(self, element, operator, **kw):
3155 if (
3156 element._is_implicitly_boolean
3157 or self.dialect.supports_native_boolean
3158 ):
3159 return "NOT %s" % self.process(element.element, **kw)
3160 else:
3161 return "%s = 0" % self.process(element.element, **kw)
3162
3163 def visit_not_match_op_binary(self, binary, operator, **kw):
3164 return "NOT %s" % self.visit_binary(
3165 binary, override_operator=operators.match_op
3166 )
3167
3168 def visit_not_in_op_binary(self, binary, operator, **kw):
3169 # The brackets are required in the NOT IN operation because the empty
3170 # case is handled using the form "(col NOT IN (null) OR 1 = 1)".
3171 # The presence of the OR makes the brackets required.
3172 return "(%s)" % self._generate_generic_binary(
3173 binary, OPERATORS[operator], **kw
3174 )
3175
3176 def visit_empty_set_op_expr(self, type_, expand_op, **kw):
3177 if expand_op is operators.not_in_op:
3178 if len(type_) > 1:
3179 return "(%s)) OR (1 = 1" % (
3180 ", ".join("NULL" for element in type_)
3181 )
3182 else:
3183 return "NULL) OR (1 = 1"
3184 elif expand_op is operators.in_op:
3185 if len(type_) > 1:
3186 return "(%s)) AND (1 != 1" % (
3187 ", ".join("NULL" for element in type_)
3188 )
3189 else:
3190 return "NULL) AND (1 != 1"
3191 else:
3192 return self.visit_empty_set_expr(type_)
3193
3194 def visit_empty_set_expr(self, element_types, **kw):
3195 raise NotImplementedError(
3196 "Dialect '%s' does not support empty set expression."
3197 % self.dialect.name
3198 )
3199
3200 def _literal_execute_expanding_parameter_literal_binds(
3201 self, parameter, values, bind_expression_template=None
3202 ):
3203 typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
3204
3205 if not values:
3206 # empty IN expression. note we don't need to use
3207 # bind_expression_template here because there are no
3208 # expressions to render.
3209
3210 if typ_dialect_impl._is_tuple_type:
3211 replacement_expression = (
3212 "VALUES " if self.dialect.tuple_in_values else ""
3213 ) + self.visit_empty_set_op_expr(
3214 parameter.type.types, parameter.expand_op
3215 )
3216
3217 else:
3218 replacement_expression = self.visit_empty_set_op_expr(
3219 [parameter.type], parameter.expand_op
3220 )
3221
3222 elif typ_dialect_impl._is_tuple_type or (
3223 typ_dialect_impl._isnull
3224 and isinstance(values[0], collections_abc.Sequence)
3225 and not isinstance(values[0], (str, bytes))
3226 ):
3227 if typ_dialect_impl._has_bind_expression:
3228 raise NotImplementedError(
3229 "bind_expression() on TupleType not supported with "
3230 "literal_binds"
3231 )
3232
3233 replacement_expression = (
3234 "VALUES " if self.dialect.tuple_in_values else ""
3235 ) + ", ".join(
3236 "(%s)"
3237 % (
3238 ", ".join(
3239 self.render_literal_value(value, param_type)
3240 for value, param_type in zip(
3241 tuple_element, parameter.type.types
3242 )
3243 )
3244 )
3245 for i, tuple_element in enumerate(values)
3246 )
3247 else:
3248 if bind_expression_template:
3249 post_compile_pattern = self._post_compile_pattern
3250 m = post_compile_pattern.search(bind_expression_template)
3251 assert m and m.group(
3252 2
3253 ), "unexpected format for expanding parameter"
3254
3255 tok = m.group(2).split("~~")
3256 be_left, be_right = tok[1], tok[3]
3257 replacement_expression = ", ".join(
3258 "%s%s%s"
3259 % (
3260 be_left,
3261 self.render_literal_value(value, parameter.type),
3262 be_right,
3263 )
3264 for value in values
3265 )
3266 else:
3267 replacement_expression = ", ".join(
3268 self.render_literal_value(value, parameter.type)
3269 for value in values
3270 )
3271
3272 return (), replacement_expression
3273
    def _literal_execute_expanding_parameter(self, name, parameter, values):
        """Expand an "expanding" (IN) parameter into individual bound
        parameters, one per element of ``values``.

        Returns ``(to_update, replacement_expression)``: ``to_update`` is
        a list of ``(param_name, value)`` pairs to add to the execution
        parameters and ``replacement_expression`` is the SQL text that
        replaces the POSTCOMPILE token.
        """
        if parameter.literal_execute:
            # literal_execute leaves no bound parameters at all; render
            # the values as inline literals instead
            return self._literal_execute_expanding_parameter_literal_binds(
                parameter, values
            )

        dialect = self.dialect
        typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect)

        if self._numeric_binds:
            bind_template = self.compilation_bindtemplate
        else:
            bind_template = self.bindtemplate

        if (
            self.dialect._bind_typing_render_casts
            and typ_dialect_impl.render_bind_cast
        ):
            # wrap each rendered placeholder in a dialect-level cast

            def _render_bindtemplate(name):
                return self.render_bind_cast(
                    parameter.type,
                    typ_dialect_impl,
                    bind_template % {"name": name},
                )

        else:

            def _render_bindtemplate(name):
                return bind_template % {"name": name}

        if not values:
            # empty IN: no parameters; render the dialect's empty-set
            # expression instead
            to_update = []
            if typ_dialect_impl._is_tuple_type:
                replacement_expression = self.visit_empty_set_op_expr(
                    parameter.type.types, parameter.expand_op
                )
            else:
                replacement_expression = self.visit_empty_set_op_expr(
                    [parameter.type], parameter.expand_op
                )

        elif typ_dialect_impl._is_tuple_type or (
            typ_dialect_impl._isnull
            and isinstance(values[0], collections_abc.Sequence)
            and not isinstance(values[0], (str, bytes))
        ):
            # tuple-valued IN: one parameter per tuple element, named
            # <name>_<tuple index>_<element index>, both one-based
            assert not typ_dialect_impl._is_array
            to_update = [
                ("%s_%s_%s" % (name, i, j), value)
                for i, tuple_element in enumerate(values, 1)
                for j, value in enumerate(tuple_element, 1)
            ]

            replacement_expression = (
                "VALUES " if dialect.tuple_in_values else ""
            ) + ", ".join(
                "(%s)"
                % (
                    ", ".join(
                        # flat index into to_update; NOTE(review): this
                        # assumes all tuples have the same length —
                        # appears to hold for tuple IN values
                        _render_bindtemplate(
                            to_update[i * len(tuple_element) + j][0]
                        )
                        for j, value in enumerate(tuple_element)
                    )
                )
                for i, tuple_element in enumerate(values)
            )
        else:
            # scalar IN: parameters named <name>_<index>, one-based
            to_update = [
                ("%s_%s" % (name, i), value)
                for i, value in enumerate(values, 1)
            ]
            replacement_expression = ", ".join(
                _render_bindtemplate(key) for key, value in to_update
            )

        return to_update, replacement_expression
3352
    def visit_binary(
        self,
        binary,
        override_operator=None,
        eager_grouping=False,
        from_linter=None,
        lateral_from_linter=None,
        **kw,
    ):
        """Render a binary expression, dispatching to an operator-specific
        ``visit_<op>_binary`` method when one exists, else using the
        generic ``OPERATORS`` opstring.

        :param override_operator: render with this operator in place of
         ``binary.operator`` (e.g. NOT MATCH rendered via MATCH).
        :param from_linter: collects join edges between FROM objects for
         cartesian-product linting.
        :param lateral_from_linter: linter for the enclosing LATERAL
         context, if any; when present, ``kw["enclosing_lateral"]`` is
         expected to be set.
        """
        if from_linter and operators.is_comparison(binary.operator):
            # a comparison joins the FROM objects of its two sides; record
            # those edges so the linter doesn't flag a cartesian product
            if lateral_from_linter is not None:
                enclosing_lateral = kw["enclosing_lateral"]
                lateral_from_linter.edges.update(
                    itertools.product(
                        _de_clone(
                            binary.left._from_objects + [enclosing_lateral]
                        ),
                        _de_clone(
                            binary.right._from_objects + [enclosing_lateral]
                        ),
                    )
                )
            else:
                from_linter.edges.update(
                    itertools.product(
                        _de_clone(binary.left._from_objects),
                        _de_clone(binary.right._from_objects),
                    )
                )

        # don't allow "? = ?" to render
        if (
            self.ansi_bind_rules
            and isinstance(binary.left, elements.BindParameter)
            and isinstance(binary.right, elements.BindParameter)
        ):
            kw["literal_execute"] = True

        operator_ = override_operator or binary.operator
        disp = self._get_operator_dispatch(operator_, "binary", None)
        if disp:
            return disp(binary, operator_, **kw)
        else:
            try:
                opstring = OPERATORS[operator_]
            except KeyError as err:
                raise exc.UnsupportedCompilationError(self, operator_) from err
            else:
                return self._generate_generic_binary(
                    binary,
                    opstring,
                    from_linter=from_linter,
                    lateral_from_linter=lateral_from_linter,
                    **kw,
                )
3408
3409 def visit_function_as_comparison_op_binary(self, element, operator, **kw):
3410 return self.process(element.sql_function, **kw)
3411
3412 def visit_mod_binary(self, binary, operator, **kw):
3413 if self.preparer._double_percents:
3414 return (
3415 self.process(binary.left, **kw)
3416 + " %% "
3417 + self.process(binary.right, **kw)
3418 )
3419 else:
3420 return (
3421 self.process(binary.left, **kw)
3422 + " % "
3423 + self.process(binary.right, **kw)
3424 )
3425
3426 def visit_custom_op_binary(self, element, operator, **kw):
3427 kw["eager_grouping"] = operator.eager_grouping
3428 return self._generate_generic_binary(
3429 element,
3430 " " + self.escape_literal_column(operator.opstring) + " ",
3431 **kw,
3432 )
3433
3434 def visit_custom_op_unary_operator(self, element, operator, **kw):
3435 return self._generate_generic_unary_operator(
3436 element, self.escape_literal_column(operator.opstring) + " ", **kw
3437 )
3438
3439 def visit_custom_op_unary_modifier(self, element, operator, **kw):
3440 return self._generate_generic_unary_modifier(
3441 element, " " + self.escape_literal_column(operator.opstring), **kw
3442 )
3443
3444 def _generate_generic_binary(
3445 self,
3446 binary: BinaryExpression[Any],
3447 opstring: str,
3448 eager_grouping: bool = False,
3449 **kw: Any,
3450 ) -> str:
3451 _in_operator_expression = kw.get("_in_operator_expression", False)
3452
3453 kw["_in_operator_expression"] = True
3454 kw["_binary_op"] = binary.operator
3455 text = (
3456 binary.left._compiler_dispatch(
3457 self, eager_grouping=eager_grouping, **kw
3458 )
3459 + opstring
3460 + binary.right._compiler_dispatch(
3461 self, eager_grouping=eager_grouping, **kw
3462 )
3463 )
3464
3465 if _in_operator_expression and eager_grouping:
3466 text = "(%s)" % text
3467 return text
3468
3469 def _generate_generic_unary_operator(self, unary, opstring, **kw):
3470 return opstring + unary.element._compiler_dispatch(self, **kw)
3471
3472 def _generate_generic_unary_modifier(self, unary, opstring, **kw):
3473 return unary.element._compiler_dispatch(self, **kw) + opstring
3474
    @util.memoized_property
    def _like_percent_literal(self):
        # reusable "'%'" literal column used to build LIKE wildcards for
        # the contains/startswith/endswith family; memoized per compiler
        return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE)
3478
3479 def visit_ilike_case_insensitive_operand(self, element, **kw):
3480 return f"lower({element.element._compiler_dispatch(self, **kw)})"
3481
3482 def visit_contains_op_binary(self, binary, operator, **kw):
3483 binary = binary._clone()
3484 percent = self._like_percent_literal
3485 binary.right = percent.concat(binary.right).concat(percent)
3486 return self.visit_like_op_binary(binary, operator, **kw)
3487
3488 def visit_not_contains_op_binary(self, binary, operator, **kw):
3489 binary = binary._clone()
3490 percent = self._like_percent_literal
3491 binary.right = percent.concat(binary.right).concat(percent)
3492 return self.visit_not_like_op_binary(binary, operator, **kw)
3493
3494 def visit_icontains_op_binary(self, binary, operator, **kw):
3495 binary = binary._clone()
3496 percent = self._like_percent_literal
3497 binary.left = ilike_case_insensitive(binary.left)
3498 binary.right = percent.concat(
3499 ilike_case_insensitive(binary.right)
3500 ).concat(percent)
3501 return self.visit_ilike_op_binary(binary, operator, **kw)
3502
3503 def visit_not_icontains_op_binary(self, binary, operator, **kw):
3504 binary = binary._clone()
3505 percent = self._like_percent_literal
3506 binary.left = ilike_case_insensitive(binary.left)
3507 binary.right = percent.concat(
3508 ilike_case_insensitive(binary.right)
3509 ).concat(percent)
3510 return self.visit_not_ilike_op_binary(binary, operator, **kw)
3511
3512 def visit_startswith_op_binary(self, binary, operator, **kw):
3513 binary = binary._clone()
3514 percent = self._like_percent_literal
3515 binary.right = percent._rconcat(binary.right)
3516 return self.visit_like_op_binary(binary, operator, **kw)
3517
3518 def visit_not_startswith_op_binary(self, binary, operator, **kw):
3519 binary = binary._clone()
3520 percent = self._like_percent_literal
3521 binary.right = percent._rconcat(binary.right)
3522 return self.visit_not_like_op_binary(binary, operator, **kw)
3523
3524 def visit_istartswith_op_binary(self, binary, operator, **kw):
3525 binary = binary._clone()
3526 percent = self._like_percent_literal
3527 binary.left = ilike_case_insensitive(binary.left)
3528 binary.right = percent._rconcat(ilike_case_insensitive(binary.right))
3529 return self.visit_ilike_op_binary(binary, operator, **kw)
3530
3531 def visit_not_istartswith_op_binary(self, binary, operator, **kw):
3532 binary = binary._clone()
3533 percent = self._like_percent_literal
3534 binary.left = ilike_case_insensitive(binary.left)
3535 binary.right = percent._rconcat(ilike_case_insensitive(binary.right))
3536 return self.visit_not_ilike_op_binary(binary, operator, **kw)
3537
3538 def visit_endswith_op_binary(self, binary, operator, **kw):
3539 binary = binary._clone()
3540 percent = self._like_percent_literal
3541 binary.right = percent.concat(binary.right)
3542 return self.visit_like_op_binary(binary, operator, **kw)
3543
3544 def visit_not_endswith_op_binary(self, binary, operator, **kw):
3545 binary = binary._clone()
3546 percent = self._like_percent_literal
3547 binary.right = percent.concat(binary.right)
3548 return self.visit_not_like_op_binary(binary, operator, **kw)
3549
3550 def visit_iendswith_op_binary(self, binary, operator, **kw):
3551 binary = binary._clone()
3552 percent = self._like_percent_literal
3553 binary.left = ilike_case_insensitive(binary.left)
3554 binary.right = percent.concat(ilike_case_insensitive(binary.right))
3555 return self.visit_ilike_op_binary(binary, operator, **kw)
3556
3557 def visit_not_iendswith_op_binary(self, binary, operator, **kw):
3558 binary = binary._clone()
3559 percent = self._like_percent_literal
3560 binary.left = ilike_case_insensitive(binary.left)
3561 binary.right = percent.concat(ilike_case_insensitive(binary.right))
3562 return self.visit_not_ilike_op_binary(binary, operator, **kw)
3563
3564 def visit_like_op_binary(self, binary, operator, **kw):
3565 escape = binary.modifiers.get("escape", None)
3566
3567 return "%s LIKE %s" % (
3568 binary.left._compiler_dispatch(self, **kw),
3569 binary.right._compiler_dispatch(self, **kw),
3570 ) + (
3571 " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
3572 if escape is not None
3573 else ""
3574 )
3575
3576 def visit_not_like_op_binary(self, binary, operator, **kw):
3577 escape = binary.modifiers.get("escape", None)
3578 return "%s NOT LIKE %s" % (
3579 binary.left._compiler_dispatch(self, **kw),
3580 binary.right._compiler_dispatch(self, **kw),
3581 ) + (
3582 " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
3583 if escape is not None
3584 else ""
3585 )
3586
3587 def visit_ilike_op_binary(self, binary, operator, **kw):
3588 if operator is operators.ilike_op:
3589 binary = binary._clone()
3590 binary.left = ilike_case_insensitive(binary.left)
3591 binary.right = ilike_case_insensitive(binary.right)
3592 # else we assume ilower() has been applied
3593
3594 return self.visit_like_op_binary(binary, operator, **kw)
3595
3596 def visit_not_ilike_op_binary(self, binary, operator, **kw):
3597 if operator is operators.not_ilike_op:
3598 binary = binary._clone()
3599 binary.left = ilike_case_insensitive(binary.left)
3600 binary.right = ilike_case_insensitive(binary.right)
3601 # else we assume ilower() has been applied
3602
3603 return self.visit_not_like_op_binary(binary, operator, **kw)
3604
3605 def visit_between_op_binary(self, binary, operator, **kw):
3606 symmetric = binary.modifiers.get("symmetric", False)
3607 return self._generate_generic_binary(
3608 binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw
3609 )
3610
3611 def visit_not_between_op_binary(self, binary, operator, **kw):
3612 symmetric = binary.modifiers.get("symmetric", False)
3613 return self._generate_generic_binary(
3614 binary,
3615 " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ",
3616 **kw,
3617 )
3618
3619 def visit_regexp_match_op_binary(
3620 self, binary: BinaryExpression[Any], operator: Any, **kw: Any
3621 ) -> str:
3622 raise exc.CompileError(
3623 "%s dialect does not support regular expressions"
3624 % self.dialect.name
3625 )
3626
3627 def visit_not_regexp_match_op_binary(
3628 self, binary: BinaryExpression[Any], operator: Any, **kw: Any
3629 ) -> str:
3630 raise exc.CompileError(
3631 "%s dialect does not support regular expressions"
3632 % self.dialect.name
3633 )
3634
3635 def visit_regexp_replace_op_binary(
3636 self, binary: BinaryExpression[Any], operator: Any, **kw: Any
3637 ) -> str:
3638 raise exc.CompileError(
3639 "%s dialect does not support regular expression replacements"
3640 % self.dialect.name
3641 )
3642
3643 def visit_dmltargetcopy(self, element, *, bindmarkers=None, **kw):
3644 if bindmarkers is None:
3645 raise exc.CompileError(
3646 "DML target objects may only be used with "
3647 "compiled INSERT or UPDATE statements"
3648 )
3649
3650 bindmarkers[element.column.key] = element
3651 return f"__BINDMARKER_~~{element.column.key}~~"
3652
    def visit_bindparam(
        self,
        bindparam,
        within_columns_clause=False,
        literal_binds=False,
        skip_bind_expression=False,
        literal_execute=False,
        render_postcompile=False,
        **kwargs,
    ):
        """Render a bound parameter, registering it in ``self.binds`` and
        related state along the way.

        :param within_columns_clause: the parameter appears in the
         columns clause; combined with ``ansi_bind_rules`` this forces
         literal execution.
        :param literal_binds: render the value inline as a SQL literal.
        :param skip_bind_expression: suppress type-level
         ``bind_expression()`` wrapping (set on the recursive call).
        :param literal_execute: render as a post-compile literal.
        :param render_postcompile: resolve POSTCOMPILE tokens in the
         final string.
        """

        if not skip_bind_expression:
            impl = bindparam.type.dialect_impl(self.dialect)
            if impl._has_bind_expression:
                # wrap the parameter in the type's bind_expression() and
                # re-process; recursion is stopped via
                # skip_bind_expression=True
                bind_expression = impl.bind_expression(bindparam)
                wrapped = self.process(
                    bind_expression,
                    skip_bind_expression=True,
                    within_columns_clause=within_columns_clause,
                    literal_binds=literal_binds and not bindparam.expanding,
                    literal_execute=literal_execute,
                    render_postcompile=render_postcompile,
                    **kwargs,
                )
                if bindparam.expanding:
                    # for postcompile w/ expanding, move the "wrapped" part
                    # of this into the inside

                    m = re.match(
                        r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
                    )
                    assert m, "unexpected format for expanding parameter"
                    wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
                        m.group(2),
                        m.group(1),
                        m.group(3),
                    )

                if literal_binds:
                    ret = self.render_literal_bindparam(
                        bindparam,
                        within_columns_clause=True,
                        bind_expression_template=wrapped,
                        **kwargs,
                    )
                    return f"({ret})"

                return wrapped

        if not literal_binds:
            literal_execute = (
                literal_execute
                or bindparam.literal_execute
                or (within_columns_clause and self.ansi_bind_rules)
            )
            post_compile = literal_execute or bindparam.expanding
        else:
            post_compile = False

        if literal_binds:
            ret = self.render_literal_bindparam(
                bindparam, within_columns_clause=True, **kwargs
            )
            if bindparam.expanding:
                ret = f"({ret})"
            return ret

        name = self._truncate_bindparam(bindparam)

        if name in self.binds:
            # another parameter was already rendered under this name;
            # reject ambiguous / conflicting reuse cases
            existing = self.binds[name]
            if existing is not bindparam:
                if (
                    (existing.unique or bindparam.unique)
                    and not existing.proxy_set.intersection(
                        bindparam.proxy_set
                    )
                    and not existing._cloned_set.intersection(
                        bindparam._cloned_set
                    )
                ):
                    raise exc.CompileError(
                        "Bind parameter '%s' conflicts with "
                        "unique bind parameter of the same name" % name
                    )
                elif existing.expanding != bindparam.expanding:
                    raise exc.CompileError(
                        "Can't reuse bound parameter name '%s' in both "
                        "'expanding' (e.g. within an IN expression) and "
                        "non-expanding contexts. If this parameter is to "
                        "receive a list/array value, set 'expanding=True' on "
                        "it for expressions that aren't IN, otherwise use "
                        "a different parameter name." % (name,)
                    )
                elif existing._is_crud or bindparam._is_crud:
                    if existing._is_crud and bindparam._is_crud:
                        # TODO: this condition is not well understood.
                        # see tests in test/sql/test_update.py
                        raise exc.CompileError(
                            "Encountered unsupported case when compiling an "
                            "INSERT or UPDATE statement. If this is a "
                            "multi-table "
                            "UPDATE statement, please provide string-named "
                            "arguments to the "
                            "values() method with distinct names; support for "
                            "multi-table UPDATE statements that "
                            "target multiple tables for UPDATE is very "
                            "limited",
                        )
                    else:
                        raise exc.CompileError(
                            f"bindparam() name '{bindparam.key}' is reserved "
                            "for automatic usage in the VALUES or SET "
                            "clause of this "
                            "insert/update statement. Please use a "
                            "name other than column name when using "
                            "bindparam() "
                            "with insert() or update() (for example, "
                            f"'b_{bindparam.key}')."
                        )

        self.binds[bindparam.key] = self.binds[name] = bindparam

        # if we are given a cache key that we're going to match against,
        # relate the bindparam here to one that is most likely present
        # in the "extracted params" portion of the cache key. this is used
        # to set up a positional mapping that is used to determine the
        # correct parameters for a subsequent use of this compiled with
        # a different set of parameter values. here, we accommodate for
        # parameters that may have been cloned both before and after the cache
        # key was been generated.
        ckbm_tuple = self._cache_key_bind_match

        if ckbm_tuple:
            ckbm, cksm = ckbm_tuple
            for bp in bindparam._cloned_set:
                if bp.key in cksm:
                    cb = cksm[bp.key]
                    ckbm[cb].append(bindparam)

        if bindparam.isoutparam:
            self.has_out_parameters = True

        if post_compile:
            if render_postcompile:
                self._render_postcompile = True

            if literal_execute:
                self.literal_execute_params |= {bindparam}
            else:
                self.post_compile_params |= {bindparam}

        ret = self.bindparam_string(
            name,
            post_compile=post_compile,
            expanding=bindparam.expanding,
            bindparam_type=bindparam.type,
            **kwargs,
        )

        if bindparam.expanding:
            # expanding parameters are always parenthesized
            ret = f"({ret})"

        return ret
3817
3818 def render_bind_cast(self, type_, dbapi_type, sqltext):
3819 raise NotImplementedError()
3820
    def render_literal_bindparam(
        self,
        bindparam,
        render_literal_value=NO_ARG,
        bind_expression_template=None,
        **kw,
    ):
        """Render a bound parameter inline as a SQL literal.

        :param render_literal_value: when passed, use this value rather
         than the parameter's own effective value.
        :param bind_expression_template: post-compile template from a
         type-level ``bind_expression()``, forwarded to the expanding
         parameter renderer.
        """
        if render_literal_value is not NO_ARG:
            value = render_literal_value
        else:
            if bindparam.value is None and bindparam.callable is None:
                # no value present; render NULL, warning if it's being
                # compared with an operator other than is / is not
                op = kw.get("_binary_op", None)
                if op and op not in (operators.is_, operators.is_not):
                    util.warn_limited(
                        "Bound parameter '%s' rendering literal NULL in a SQL "
                        "expression; comparisons to NULL should not use "
                        "operators outside of 'is' or 'is not'",
                        (bindparam.key,),
                    )
                return self.process(sqltypes.NULLTYPE, **kw)
            value = bindparam.effective_value

        if bindparam.expanding:
            # expanding (IN) parameters: render each element as a literal
            leep = self._literal_execute_expanding_parameter_literal_binds
            to_update, replacement_expr = leep(
                bindparam,
                value,
                bind_expression_template=bind_expression_template,
            )
            return replacement_expr
        else:
            return self.render_literal_value(value, bindparam.type)
3853
    def render_literal_value(
        self, value: Any, type_: sqltypes.TypeEngine[Any]
    ) -> str:
        """Render the value of a bind parameter as a quoted literal.

        This is used for statement sections that do not accept bind parameters
        on the target driver/database.

        This should be implemented by subclasses using the quoting services
        of the DBAPI.

        :raises exc.CompileError: if the type provides no literal
         processor, or its processor fails for this value.
        """

        if value is None and not type_.should_evaluate_none:
            # issue #10535 - handle NULL in the compiler without placing
            # this onto each type, except for "evaluate None" types
            # (e.g. JSON)
            return self.process(elements.Null._instance())

        processor = type_._cached_literal_processor(self.dialect)
        if processor:
            try:
                return processor(value)
            except Exception as e:
                # chain the type-level failure for diagnostics
                raise exc.CompileError(
                    f"Could not render literal value "
                    f'"{sql_util._repr_single_value(value)}" '
                    f"with datatype "
                    f"{type_}; see parent stack trace for "
                    "more detail."
                ) from e

        else:
            raise exc.CompileError(
                f"No literal value renderer is available for literal value "
                f'"{sql_util._repr_single_value(value)}" '
                f"with datatype {type_}"
            )
3892
3893 def _truncate_bindparam(self, bindparam):
3894 if bindparam in self.bind_names:
3895 return self.bind_names[bindparam]
3896
3897 bind_name = bindparam.key
3898 if isinstance(bind_name, elements._truncated_label):
3899 bind_name = self._truncated_identifier("bindparam", bind_name)
3900
3901 # add to bind_names for translation
3902 self.bind_names[bindparam] = bind_name
3903
3904 return bind_name
3905
3906 def _truncated_identifier(
3907 self, ident_class: str, name: _truncated_label
3908 ) -> str:
3909 if (ident_class, name) in self.truncated_names:
3910 return self.truncated_names[(ident_class, name)]
3911
3912 anonname = name.apply_map(self.anon_map)
3913
3914 if len(anonname) > self.label_length - 6:
3915 counter = self._truncated_counters.get(ident_class, 1)
3916 truncname = (
3917 anonname[0 : max(self.label_length - 6, 0)]
3918 + "_"
3919 + hex(counter)[2:]
3920 )
3921 self._truncated_counters[ident_class] = counter + 1
3922 else:
3923 truncname = anonname
3924 self.truncated_names[(ident_class, name)] = truncname
3925 return truncname
3926
3927 def _anonymize(self, name: str) -> str:
3928 return name % self.anon_map
3929
    def bindparam_string(
        self,
        name: str,
        post_compile: bool = False,
        expanding: bool = False,
        escaped_from: Optional[str] = None,
        bindparam_type: Optional[TypeEngine[Any]] = None,
        accumulate_bind_names: Optional[Set[str]] = None,
        visited_bindparam: Optional[List[str]] = None,
        **kw: Any,
    ) -> str:
        """Return the placeholder text for bind parameter *name*: either
        the dialect's bind template or a ``__[POSTCOMPILE_...]`` token,
        applying name escaping and render casts as needed."""
        # TODO: accumulate_bind_names is passed by crud.py to gather
        # names on a per-value basis, visited_bindparam is passed by
        # visit_insert() to collect all parameters in the statement.
        # see if this gathering can be simplified somehow
        if accumulate_bind_names is not None:
            accumulate_bind_names.add(name)
        if visited_bindparam is not None:
            visited_bindparam.append(name)

        if not escaped_from:
            if self._bind_translate_re.search(name):
                # not quite the translate use case as we want to
                # also get a quick boolean if we even found
                # unusual characters in the name
                new_name = self._bind_translate_re.sub(
                    lambda m: self._bind_translate_chars[m.group(0)],
                    name,
                )
                escaped_from = name
                name = new_name

        if escaped_from:
            # record original -> escaped so execution can translate
            # incoming parameter dictionaries
            self.escaped_bind_names = self.escaped_bind_names.union(
                {escaped_from: name}
            )
        if post_compile:
            ret = "__[POSTCOMPILE_%s]" % name
            if expanding:
                # for expanding, bound parameters or literal values will be
                # rendered per item
                return ret

            # otherwise, for non-expanding "literal execute", apply
            # bind casts as determined by the datatype
            if bindparam_type is not None:
                type_impl = bindparam_type._unwrapped_dialect_impl(
                    self.dialect
                )
                if type_impl.render_literal_cast:
                    ret = self.render_bind_cast(bindparam_type, type_impl, ret)
            return ret
        elif self.state is CompilerState.COMPILING:
            ret = self.compilation_bindtemplate % {"name": name}
        else:
            ret = self.bindtemplate % {"name": name}

        if (
            bindparam_type is not None
            and self.dialect._bind_typing_render_casts
        ):
            type_impl = bindparam_type._unwrapped_dialect_impl(self.dialect)
            if type_impl.render_bind_cast:
                ret = self.render_bind_cast(bindparam_type, type_impl, ret)

        return ret
3996
3997 def _dispatch_independent_ctes(self, stmt, kw):
3998 local_kw = kw.copy()
3999 local_kw.pop("cte_opts", None)
4000 for cte, opt in zip(
4001 stmt._independent_ctes, stmt._independent_ctes_opts
4002 ):
4003 cte._compiler_dispatch(self, cte_opts=opt, **local_kw)
4004
    def visit_cte(
        self,
        cte: CTE,
        asfrom: bool = False,
        ashint: bool = False,
        fromhints: Optional[_FromHintsType] = None,
        visiting_cte: Optional[CTE] = None,
        from_linter: Optional[FromLinter] = None,
        cte_opts: selectable._CTEOpts = selectable._CTEOpts(False),
        **kwargs: Any,
    ) -> Optional[str]:
        """Render a CTE reference, accumulating its WITH-clause text in
        ``self.ctes`` keyed by nesting level and name.

        Returns the name/alias text to use at the point of reference when
        ``asfrom`` is True, else ``None``.  De-duplicates same-named CTEs
        that are clones/annotations of one another and raises
        ``CompileError`` for unrelated CTEs sharing a name.
        """
        self_ctes = self._init_cte_state()
        assert self_ctes is self.ctes

        kwargs["visiting_cte"] = cte

        cte_name = cte.name

        if isinstance(cte_name, elements._truncated_label):
            cte_name = self._truncated_identifier("alias", cte_name)

        is_new_cte = True
        embedded_in_current_named_cte = False

        _reference_cte = cte._get_reference_cte()

        nesting = cte.nesting or cte_opts.nesting

        # check for CTE already encountered
        if _reference_cte in self.level_name_by_cte:
            cte_level, _, existing_cte_opts = self.level_name_by_cte[
                _reference_cte
            ]
            assert _ == cte_name

            cte_level_name = (cte_level, cte_name)
            existing_cte = self.ctes_by_level_name[cte_level_name]

            # check if we are receiving it here with a specific
            # "nest_here" location; if so, move it to this location

            if cte_opts.nesting:
                if existing_cte_opts.nesting:
                    raise exc.CompileError(
                        "CTE is stated as 'nest_here' in "
                        "more than one location"
                    )

                old_level_name = (cte_level, cte_name)
                cte_level = len(self.stack) if nesting else 1
                cte_level_name = new_level_name = (cte_level, cte_name)

                del self.ctes_by_level_name[old_level_name]
                self.ctes_by_level_name[new_level_name] = existing_cte
                self.level_name_by_cte[_reference_cte] = new_level_name + (
                    cte_opts,
                )

        else:
            cte_level = len(self.stack) if nesting else 1
            cte_level_name = (cte_level, cte_name)

            if cte_level_name in self.ctes_by_level_name:
                existing_cte = self.ctes_by_level_name[cte_level_name]
            else:
                existing_cte = None

        if existing_cte is not None:
            embedded_in_current_named_cte = visiting_cte is existing_cte

            # we've generated a same-named CTE that we are enclosed in,
            # or this is the same CTE. just return the name.
            if cte is existing_cte._restates or cte is existing_cte:
                is_new_cte = False
            elif existing_cte is cte._restates:
                # we've generated a same-named CTE that is
                # enclosed in us - we take precedence, so
                # discard the text for the "inner".
                del self_ctes[existing_cte]

                existing_cte_reference_cte = existing_cte._get_reference_cte()

                assert existing_cte_reference_cte is _reference_cte
                assert existing_cte_reference_cte is existing_cte

                del self.level_name_by_cte[existing_cte_reference_cte]
            else:
                if (
                    # if the two CTEs have the same hash, which we expect
                    # here means that one/both is an annotated of the other
                    (hash(cte) == hash(existing_cte))
                    # or...
                    or (
                        (
                            # if they are clones, i.e. they came from the ORM
                            # or some other visit method
                            cte._is_clone_of is not None
                            or existing_cte._is_clone_of is not None
                        )
                        # and are deep-copy identical
                        and cte.compare(existing_cte)
                    )
                ):
                    # then consider these two CTEs the same
                    is_new_cte = False
                else:
                    # otherwise these are two CTEs that either will render
                    # differently, or were indicated separately by the user,
                    # with the same name
                    raise exc.CompileError(
                        "Multiple, unrelated CTEs found with "
                        "the same name: %r" % cte_name
                    )

        if not asfrom and not is_new_cte:
            return None

        if cte._cte_alias is not None:
            pre_alias_cte = cte._cte_alias
            cte_pre_alias_name = cte._cte_alias.name
            if isinstance(cte_pre_alias_name, elements._truncated_label):
                cte_pre_alias_name = self._truncated_identifier(
                    "alias", cte_pre_alias_name
                )
        else:
            pre_alias_cte = cte
            cte_pre_alias_name = None

        if is_new_cte:
            # register this CTE at its level/name and render its body
            self.ctes_by_level_name[cte_level_name] = cte
            self.level_name_by_cte[_reference_cte] = cte_level_name + (
                cte_opts,
            )

            if pre_alias_cte not in self.ctes:
                self.visit_cte(pre_alias_cte, **kwargs)

            if not cte_pre_alias_name and cte not in self_ctes:
                if cte.recursive:
                    self.ctes_recursive = True
                text = self.preparer.format_alias(cte, cte_name)
                if cte.recursive or cte.element.name_cte_columns:
                    col_source = cte.element

                    # TODO: can we get at the .columns_plus_names collection
                    # that is already (or will be?) generated for the SELECT
                    # rather than calling twice?
                    recur_cols = [
                        # TODO: proxy_name is not technically safe,
                        # see test_cte->
                        # test_with_recursive_no_name_currently_buggy. not
                        # clear what should be done with such a case
                        fallback_label_name or proxy_name
                        for (
                            _,
                            proxy_name,
                            fallback_label_name,
                            c,
                            repeated,
                        ) in (col_source._generate_columns_plus_names(True))
                        if not repeated
                    ]

                    text += "(%s)" % (
                        ", ".join(
                            self.preparer.format_label_name(
                                ident, anon_map=self.anon_map
                            )
                            for ident in recur_cols
                        )
                    )

                assert kwargs.get("subquery", False) is False

                if not self.stack:
                    # toplevel, this is a stringify of the
                    # cte directly. just compile the inner
                    # the way alias() does.
                    return cte.element._compiler_dispatch(
                        self, asfrom=asfrom, **kwargs
                    )
                else:
                    prefixes = self._generate_prefixes(
                        cte, cte._prefixes, **kwargs
                    )
                    inner = cte.element._compiler_dispatch(
                        self, asfrom=True, **kwargs
                    )

                    text += " AS %s\n(%s)" % (prefixes, inner)

                    if cte._suffixes:
                        text += " " + self._generate_prefixes(
                            cte, cte._suffixes, **kwargs
                        )

                    self_ctes[cte] = text

        if asfrom:
            if from_linter:
                from_linter.froms[cte._de_clone()] = cte_name

            if not is_new_cte and embedded_in_current_named_cte:
                return self.preparer.format_alias(cte, cte_name)

            if cte_pre_alias_name:
                text = self.preparer.format_alias(cte, cte_pre_alias_name)
                if self.preparer._requires_quotes(cte_name):
                    cte_name = self.preparer.quote(cte_name)
                text += self.get_render_as_alias_suffix(cte_name)
                return text  # type: ignore[no-any-return]
            else:
                return self.preparer.format_alias(cte, cte_name)

        return None
4220
4221 def visit_table_valued_alias(self, element, **kw):
4222 if element.joins_implicitly:
4223 kw["from_linter"] = None
4224 if element._is_lateral:
4225 return self.visit_lateral(element, **kw)
4226 else:
4227 return self.visit_alias(element, **kw)
4228
    def visit_table_valued_column(self, element, **kw):
        """Render a column of a table-valued expression; rendering is
        delegated unchanged to the plain column visitor."""
        return self.visit_column(element, **kw)
4231
    def visit_alias(
        self,
        alias,
        asfrom=False,
        ashint=False,
        iscrud=False,
        fromhints=None,
        subquery=False,
        lateral=False,
        enclosing_alias=None,
        from_linter=None,
        **kwargs,
    ):
        """Render an alias construct.

        :param asfrom: render as a FROM element; the inner element is
         followed by an "AS <name>" suffix.
        :param ashint: render only the alias name itself, for use inside
         dialect hint text.
        :param subquery: parenthesize the inner element.
        :param lateral: the alias is being rendered within a LATERAL
         context.
        :param enclosing_alias: the alias construct immediately enclosing
         this element, if any; used to avoid aliasing the same element
         twice.
        """
        if lateral:
            if "enclosing_lateral" not in kwargs:
                # if lateral is set and enclosing_lateral is not
                # present, we assume we are being called directly
                # from visit_lateral() and we need to set enclosing_lateral.
                assert alias._is_lateral
                kwargs["enclosing_lateral"] = alias

            # for lateral objects, we track a second from_linter that is...
            # lateral! to the level above us.
            if (
                from_linter
                and "lateral_from_linter" not in kwargs
                and "enclosing_lateral" in kwargs
            ):
                kwargs["lateral_from_linter"] = from_linter

        if enclosing_alias is not None and enclosing_alias.element is alias:
            # the enclosing alias wraps this same element; dispatch the
            # inner element directly rather than aliasing it again
            inner = alias.element._compiler_dispatch(
                self,
                asfrom=asfrom,
                ashint=ashint,
                iscrud=iscrud,
                fromhints=fromhints,
                lateral=lateral,
                enclosing_alias=alias,
                **kwargs,
            )
            if subquery and (asfrom or lateral):
                inner = "(%s)" % (inner,)
            return inner
        else:
            kwargs["enclosing_alias"] = alias

        if asfrom or ashint:
            if isinstance(alias.name, elements._truncated_label):
                # anonymous / truncatable name; convert to an identifier
                # that fits the dialect's length limits
                alias_name = self._truncated_identifier("alias", alias.name)
            else:
                alias_name = alias.name

            if ashint:
                return self.preparer.format_alias(alias, alias_name)
            elif asfrom:
                if from_linter:
                    from_linter.froms[alias._de_clone()] = alias_name

                inner = alias.element._compiler_dispatch(
                    self, asfrom=True, lateral=lateral, **kwargs
                )
                if subquery:
                    inner = "(%s)" % (inner,)

                ret = inner + self.get_render_as_alias_suffix(
                    self.preparer.format_alias(alias, alias_name)
                )

                if alias._supports_derived_columns and alias._render_derived:
                    # render a derived column list, e.g.
                    # "AS name(col1, col2)", optionally including types
                    ret += "(%s)" % (
                        ", ".join(
                            "%s%s"
                            % (
                                self.preparer.quote(col.name),
                                (
                                    " %s"
                                    % self.dialect.type_compiler_instance.process(
                                        col.type, **kwargs
                                    )
                                    if alias._render_derived_w_types
                                    else ""
                                ),
                            )
                            for col in alias.c
                        )
                    )

                if fromhints and alias in fromhints:
                    ret = self.format_from_hint_text(
                        ret, alias, fromhints[alias], iscrud
                    )

                return ret
        else:
            # note we cancel the "subquery" flag here as well
            return alias.element._compiler_dispatch(
                self, lateral=lateral, **kwargs
            )
4331
    def visit_subquery(self, subquery, **kw):
        """Render a subquery; delegates to :meth:`.visit_alias` with the
        ``subquery`` flag set so the inner element is parenthesized."""
        # mutate kw rather than pass subquery=True positionally, so an
        # already-present "subquery" key is overwritten instead of raising
        kw["subquery"] = True
        return self.visit_alias(subquery, **kw)
4335
4336 def visit_lateral(self, lateral_, **kw):
4337 kw["lateral"] = True
4338 return "LATERAL %s" % self.visit_alias(lateral_, **kw)
4339
4340 def visit_tablesample(self, tablesample, asfrom=False, **kw):
4341 text = "%s TABLESAMPLE %s" % (
4342 self.visit_alias(tablesample, asfrom=True, **kw),
4343 tablesample._get_method()._compiler_dispatch(self, **kw),
4344 )
4345
4346 if tablesample.seed is not None:
4347 text += " REPEATABLE (%s)" % (
4348 tablesample.seed._compiler_dispatch(self, **kw)
4349 )
4350
4351 return text
4352
4353 def _render_values(self, element, **kw):
4354 kw.setdefault("literal_binds", element.literal_binds)
4355 tuples = ", ".join(
4356 self.process(
4357 elements.Tuple(
4358 types=element._column_types, *elem
4359 ).self_group(),
4360 **kw,
4361 )
4362 for chunk in element._data
4363 for elem in chunk
4364 )
4365 return f"VALUES {tuples}"
4366
    def visit_values(
        self, element, asfrom=False, from_linter=None, visiting_cte=None, **kw
    ):
        """Render a VALUES construct, optionally as a named FROM element
        with a derived column list, optionally with a LATERAL prefix."""

        if element._independent_ctes:
            self._dispatch_independent_ctes(element, kw)

        v = self._render_values(element, **kw)

        if element._unnamed:
            name = None
        elif isinstance(element.name, elements._truncated_label):
            # anonymous / truncatable name; convert to a
            # dialect-length-appropriate identifier
            name = self._truncated_identifier("values", element.name)
        else:
            name = element.name

        if element._is_lateral:
            lateral = "LATERAL "
        else:
            lateral = ""

        if asfrom:
            if from_linter:
                from_linter.froms[element._de_clone()] = (
                    name if name is not None else "(unnamed VALUES element)"
                )

            if visiting_cte is not None and visiting_cte.element is element:
                # the VALUES is the direct element of a CTE; no aliasing
                # is applied here (and LATERAL is not supported)
                if element._is_lateral:
                    raise exc.CompileError(
                        "Can't use a LATERAL VALUES expression inside of a CTE"
                    )
            elif name:
                # named VALUES used as a FROM element: render
                # "(VALUES ...) AS name (col, ...)"; the derived column
                # names render without a table prefix
                kw["include_table"] = False
                v = "%s(%s)%s (%s)" % (
                    lateral,
                    v,
                    self.get_render_as_alias_suffix(self.preparer.quote(name)),
                    (
                        ", ".join(
                            c._compiler_dispatch(self, **kw)
                            for c in element.columns
                        )
                    ),
                )
            else:
                v = "%s(%s)" % (lateral, v)
        return v
4415
4416 def visit_scalar_values(self, element, **kw):
4417 return f"({self._render_values(element, **kw)})"
4418
4419 def get_render_as_alias_suffix(self, alias_name_text):
4420 return " AS " + alias_name_text
4421
    def _add_to_result_map(
        self,
        keyname: Optional[str],
        name: str,
        objects: Tuple[Any, ...],
        type_: TypeEngine[Any],
    ) -> None:
        """Add an entry to the result column map used for result-row
        targeting.

        :param keyname: key for the entry; ``None`` or ``"*"`` indicates a
         textual / wildcard column, which disables ordered-column matching.
        :param objects: target objects for the entry; must be non-empty.
        :raises CompileError: if a tuple type is being SELECTed directly.
        """

        # note objects must be non-empty for cursor.py to handle the
        # collection properly
        assert objects

        if keyname is None or keyname == "*":
            # textual / wildcard column: result rows can no longer be
            # matched positionally by ordered column name
            self._ordered_columns = False
            self._ad_hoc_textual = True
        if type_._is_tuple_type:
            raise exc.CompileError(
                "Most backends don't support SELECTing "
                "from a tuple() object. If this is an ORM query, "
                "consider using the Bundle object."
            )
        self._result_columns.append(
            ResultColumnsEntry(keyname, name, objects, type_)
        )
4446
4447 def _label_returning_column(
4448 self, stmt, column, populate_result_map, column_clause_args=None, **kw
4449 ):
4450 """Render a column with necessary labels inside of a RETURNING clause.
4451
4452 This method is provided for individual dialects in place of calling
4453 the _label_select_column method directly, so that the two use cases
4454 of RETURNING vs. SELECT can be disambiguated going forward.
4455
4456 .. versionadded:: 1.4.21
4457
4458 """
4459 return self._label_select_column(
4460 None,
4461 column,
4462 populate_result_map,
4463 False,
4464 {} if column_clause_args is None else column_clause_args,
4465 **kw,
4466 )
4467
    def _label_select_column(
        self,
        select,
        column,
        populate_result_map,
        asfrom,
        column_clause_args,
        name=None,
        proxy_name=None,
        fallback_label_name=None,
        within_columns_clause=True,
        column_is_repeated=False,
        need_column_expressions=False,
        include_table=True,
    ):
        """produce labeled columns present in a select().

        Renders one column inside the columns clause of a SELECT (or, for
        legacy dialect code, a RETURNING clause), adding an "AS <label>"
        where the heuristics below call for one, and optionally recording
        the column in the compiler's result map.

        :param select: the enclosing SELECT, or ``None`` for RETURNING use.
        :param column: the column element to render.
        :param populate_result_map: when True, record this column in the
         result map for result-row targeting.
        :param asfrom: the enclosing SELECT is itself used as a FROM
         element.
        :param name: explicit label name determined by the caller, if any;
         requires ``proxy_name`` as well.
        :param fallback_label_name: label name to use when a label is
         rendered but no explicit ``name`` was given.
        """
        impl = column.type.dialect_impl(self.dialect)

        if impl._has_column_expression and (
            need_column_expressions or populate_result_map
        ):
            # the type provides a SQL-level column_expression() wrapper
            col_expr = impl.column_expression(column)
        else:
            col_expr = column

        if populate_result_map:
            # pass an "add_to_result_map" callable into the compilation
            # of embedded columns.  this collects information about the
            # column as it will be fetched in the result and is coordinated
            # with cursor.description when the query is executed.
            add_to_result_map = self._add_to_result_map

            # if the SELECT statement told us this column is a repeat,
            # wrap the callable with one that prevents the addition of the
            # targets
            if column_is_repeated:
                _add_to_result_map = add_to_result_map

                def add_to_result_map(keyname, name, objects, type_):
                    _add_to_result_map(keyname, name, (keyname,), type_)

            # if we redefined col_expr for type expressions, wrap the
            # callable with one that adds the original column to the targets
            elif col_expr is not column:
                _add_to_result_map = add_to_result_map

                def add_to_result_map(keyname, name, objects, type_):
                    _add_to_result_map(
                        keyname, name, (column,) + objects, type_
                    )

        else:
            add_to_result_map = None

        # this method is used by some of the dialects for RETURNING,
        # which has different inputs.  _label_returning_column was added
        # as the better target for this now however for 1.4 we will keep
        # _label_select_column directly compatible with this use case.
        # these assertions right now set up the current expected inputs
        assert within_columns_clause, (
            "_label_select_column is only relevant within "
            "the columns clause of a SELECT or RETURNING"
        )
        if isinstance(column, elements.Label):
            if col_expr is not column:
                result_expr = _CompileLabel(
                    col_expr, column.name, alt_names=(column.element,)
                )
            else:
                result_expr = col_expr

        elif name:
            # here, _columns_plus_names has determined there's an explicit
            # label name we need to use.  this is the default for
            # tablenames_plus_columnnames as well as when columns are being
            # deduplicated on name

            assert (
                proxy_name is not None
            ), "proxy_name is required if 'name' is passed"

            result_expr = _CompileLabel(
                col_expr,
                name,
                alt_names=(
                    proxy_name,
                    # this is a hack to allow legacy result column lookups
                    # to work as they did before; this goes away in 2.0.
                    # TODO: this only seems to be tested indirectly
                    # via test/orm/test_deprecations.py.   should be a
                    # resultset test for this
                    column._tq_label,
                ),
            )
        else:
            # determine here whether this column should be rendered in
            # a labelled context or not, as we were given no required label
            # name from the caller. Here we apply heuristics based on the
            # kind of SQL expression involved.

            if col_expr is not column:
                # type-specific expression wrapping the given column,
                # so we render a label
                render_with_label = True
            elif isinstance(column, elements.ColumnClause):
                # table-bound column, we render its name as a label if we are
                # inside of a subquery only
                render_with_label = (
                    asfrom
                    and not column.is_literal
                    and column.table is not None
                )
            elif isinstance(column, elements.TextClause):
                render_with_label = False
            elif isinstance(column, elements.UnaryExpression):
                # unary expression.   notes added as of #12681
                #
                # By convention, the visit_unary() method
                # itself does not add an entry to the result map, and relies
                # upon either the inner expression creating a result map
                # entry, or if not, by creating a label here that produces
                # the result map entry. Where that happens is based on whether
                # or not the element immediately inside the unary is a
                # NamedColumn subclass or not.
                #
                # Now, this also impacts how the SELECT is written; if
                # we decide to generate a label here, we get the usual
                # "~(x+y) AS anon_1" thing in the columns clause. If we
                # don't, we don't get an AS at all, we get like
                # "~table.column".
                #
                # But here is the important thing as of modernish (like 1.4)
                # versions of SQLAlchemy - **whether or not the AS <label>
                # is present in the statement is not actually important**.
                # We target result columns **positionally** for a fully
                # compiled ``Select()`` object; before 1.4 we needed those
                # labels to match in cursor.description etc etc but now it
                # really doesn't matter.
                # So really, we could set render_with_label True in all cases.
                # Or we could just have visit_unary() populate the result map
                # in all cases.
                #
                # What we're doing here is strictly trying to not rock the
                # boat too much with when we do/don't render "AS label";
                # labels being present helps in the edge cases that we
                # "fall back" to named cursor.description matching, labels
                # not being present for columns keeps us from having awkward
                # phrases like "SELECT DISTINCT table.x AS x".
                render_with_label = (
                    (
                        # exception case to detect if we render "not boolean"
                        # as "not <col>" for native boolean or "<col> = 1"
                        # for non-native boolean.  this is controlled by
                        # visit_is_<true|false>_unary_operator
                        column.operator
                        in (operators.is_false, operators.is_true)
                        and not self.dialect.supports_native_boolean
                    )
                    or column._wraps_unnamed_column()
                    or asfrom
                )
            elif (
                # general class of expressions that don't have a SQL-column
                # addressible name.  includes scalar selects, bind parameters,
                # SQL functions, others
                not isinstance(column, elements.NamedColumn)
                # deeper check that indicates there's no natural "name" to
                # this element, which accommodates for custom SQL constructs
                # that might have a ".name" attribute (but aren't SQL
                # functions) but are not implementing this more recently
                # added base class.  in theory the "NamedColumn" check should
                # be enough, however here we seek to maintain legacy
                # behaviors as well.
                and column._non_anon_label is None
            ):
                render_with_label = True
            else:
                render_with_label = False

            if render_with_label:
                if not fallback_label_name:
                    # used by the RETURNING case right now.  we generate it
                    # here as 3rd party dialects may be referring to
                    # _label_select_column method directly instead of the
                    # just-added _label_returning_column method
                    assert not column_is_repeated
                    fallback_label_name = column._anon_name_label

                # ensure the fallback name participates in identifier
                # truncation for the dialect
                fallback_label_name = (
                    elements._truncated_label(fallback_label_name)
                    if not isinstance(
                        fallback_label_name, elements._truncated_label
                    )
                    else fallback_label_name
                )

                result_expr = _CompileLabel(
                    col_expr, fallback_label_name, alt_names=(proxy_name,)
                )
            else:
                result_expr = col_expr

        column_clause_args.update(
            within_columns_clause=within_columns_clause,
            add_to_result_map=add_to_result_map,
            include_table=include_table,
        )
        return result_expr._compiler_dispatch(self, **column_clause_args)
4676
4677 def format_from_hint_text(self, sqltext, table, hint, iscrud):
4678 hinttext = self.get_from_hint_text(table, hint)
4679 if hinttext:
4680 sqltext += " " + hinttext
4681 return sqltext
4682
    def get_select_hint_text(self, byfroms):
        """Hook returning SELECT-level hint text given the per-FROM hint
        mapping; the default implementation returns ``None`` (no text)."""
        return None
4685
    def get_from_hint_text(
        self, table: FromClause, text: Optional[str]
    ) -> Optional[str]:
        """Hook returning hint text for a FROM element; the default
        implementation returns ``None`` (no text)."""
        return None
4690
    def get_crud_hint_text(self, table, text):
        """Hook returning hint text for an INSERT/UPDATE/DELETE target
        table; the default implementation returns ``None`` (no text)."""
        return None
4693
4694 def get_statement_hint_text(self, hint_texts):
4695 return " ".join(hint_texts)
4696
    # stack entry used when compiling at the topmost level, i.e. when
    # self.stack is empty; supplies empty "correlate_froms" /
    # "asfrom_froms" collections.
    _default_stack_entry: _CompilerStackEntry

    # the immutabledict is built only at runtime; the annotation above
    # serves the type checkers
    if not typing.TYPE_CHECKING:
        _default_stack_entry = util.immutabledict(
            [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
        )
4703
4704 def _display_froms_for_select(
4705 self, select_stmt, asfrom, lateral=False, **kw
4706 ):
4707 # utility method to help external dialects
4708 # get the correct from list for a select.
4709 # specifically the oracle dialect needs this feature
4710 # right now.
4711 toplevel = not self.stack
4712 entry = self._default_stack_entry if toplevel else self.stack[-1]
4713
4714 compile_state = select_stmt._compile_state_factory(select_stmt, self)
4715
4716 correlate_froms = entry["correlate_froms"]
4717 asfrom_froms = entry["asfrom_froms"]
4718
4719 if asfrom and not lateral:
4720 froms = compile_state._get_display_froms(
4721 explicit_correlate_froms=correlate_froms.difference(
4722 asfrom_froms
4723 ),
4724 implicit_correlate_froms=(),
4725 )
4726 else:
4727 froms = compile_state._get_display_froms(
4728 explicit_correlate_froms=correlate_froms,
4729 implicit_correlate_froms=asfrom_froms,
4730 )
4731 return froms
4732
    translate_select_structure: Any = None
    """if not ``None``, should be a callable which accepts ``(select_stmt,
    **kw)`` and returns a new select object.  This is used for structural
    changes to a SELECT statement, mostly to accommodate LIMIT/OFFSET
    schemes on backends that require restructuring.

    """
4739
    def visit_select(
        self,
        select_stmt,
        asfrom=False,
        insert_into=False,
        fromhints=None,
        compound_index=None,
        select_wraps_for=None,
        lateral=False,
        from_linter=None,
        **kwargs,
    ):
        """Render a SELECT statement.

        The compile_state factory invoked here may replace the given
        statement with a different one (e.g. an ORM-state SELECT becomes a
        Core SELECT); dialects may additionally restructure the statement
        via the ``translate_select_structure`` hook.

        :param asfrom: this SELECT is being used as a FROM element.
        :param insert_into: this SELECT is the source of an INSERT.
        :param compound_index: position of this SELECT within an enclosing
         CompoundSelect, or ``None`` if not within one.
        :param select_wraps_for: must be ``None``; retained for backwards
         signature compatibility only.
        """
        assert select_wraps_for is None, (
            "SQLAlchemy 1.4 requires use of "
            "the translate_select_structure hook for structural "
            "translations of SELECT objects"
        )

        # initial setup of SELECT. the compile_state_factory may now
        # be creating a totally different SELECT from the one that was
        # passed in. for ORM use this will convert from an ORM-state
        # SELECT to a regular "Core" SELECT. other composed operations
        # such as computation of joins will be performed.

        kwargs["within_columns_clause"] = False

        compile_state = select_stmt._compile_state_factory(
            select_stmt, self, **kwargs
        )
        kwargs["ambiguous_table_name_map"] = (
            compile_state._ambiguous_table_name_map
        )

        select_stmt = compile_state.statement

        toplevel = not self.stack

        if toplevel and not self.compile_state:
            self.compile_state = compile_state

        is_embedded_select = compound_index is not None or insert_into

        # translate step for Oracle, SQL Server which often need to
        # restructure the SELECT to allow for LIMIT/OFFSET and possibly
        # other conditions
        if self.translate_select_structure:
            new_select_stmt = self.translate_select_structure(
                select_stmt, asfrom=asfrom, **kwargs
            )

            # if SELECT was restructured, maintain a link to the originals
            # and assemble a new compile state
            if new_select_stmt is not select_stmt:
                compile_state_wraps_for = compile_state
                select_wraps_for = select_stmt
                select_stmt = new_select_stmt

                compile_state = select_stmt._compile_state_factory(
                    select_stmt, self, **kwargs
                )
                select_stmt = compile_state.statement

        entry = self._default_stack_entry if toplevel else self.stack[-1]

        populate_result_map = need_column_expressions = (
            toplevel
            or entry.get("need_result_map_for_compound", False)
            or entry.get("need_result_map_for_nested", False)
        )

        # indicates there is a CompoundSelect in play and we are not the
        # first select
        if compound_index:
            populate_result_map = False

        # this was first proposed as part of #3372; however, it is not
        # reached in current tests and could possibly be an assertion
        # instead.
        if not populate_result_map and "add_to_result_map" in kwargs:
            del kwargs["add_to_result_map"]

        froms = self._setup_select_stack(
            select_stmt, compile_state, entry, asfrom, lateral, compound_index
        )

        column_clause_args = kwargs.copy()
        column_clause_args.update(
            {"within_label_clause": False, "within_columns_clause": False}
        )

        text = "SELECT "  # we're off to a good start !

        if select_stmt._post_select_clause is not None:
            psc = self.process(select_stmt._post_select_clause, **kwargs)
            if psc is not None:
                text += psc + " "

        if select_stmt._hints:
            hint_text, byfrom = self._setup_select_hints(select_stmt)
            if hint_text:
                text += hint_text + " "
        else:
            byfrom = None

        if select_stmt._independent_ctes:
            self._dispatch_independent_ctes(select_stmt, kwargs)

        if select_stmt._prefixes:
            text += self._generate_prefixes(
                select_stmt, select_stmt._prefixes, **kwargs
            )

        text += self.get_select_precolumns(select_stmt, **kwargs)

        if select_stmt._pre_columns_clause is not None:
            pcc = self.process(select_stmt._pre_columns_clause, **kwargs)
            if pcc is not None:
                text += pcc + " "

        # the actual list of columns to print in the SELECT column list.
        inner_columns = [
            c
            for c in [
                self._label_select_column(
                    select_stmt,
                    column,
                    populate_result_map,
                    asfrom,
                    column_clause_args,
                    name=name,
                    proxy_name=proxy_name,
                    fallback_label_name=fallback_label_name,
                    column_is_repeated=repeated,
                    need_column_expressions=need_column_expressions,
                )
                for (
                    name,
                    proxy_name,
                    fallback_label_name,
                    column,
                    repeated,
                ) in compile_state.columns_plus_names
            ]
            if c is not None
        ]

        if populate_result_map and select_wraps_for is not None:
            # if this select was generated from translate_select,
            # rewrite the targeted columns in the result map

            translate = dict(
                zip(
                    [
                        name
                        for (
                            key,
                            proxy_name,
                            fallback_label_name,
                            name,
                            repeated,
                        ) in compile_state.columns_plus_names
                    ],
                    [
                        name
                        for (
                            key,
                            proxy_name,
                            fallback_label_name,
                            name,
                            repeated,
                        ) in compile_state_wraps_for.columns_plus_names
                    ],
                )
            )

            self._result_columns = [
                ResultColumnsEntry(
                    key, name, tuple(translate.get(o, o) for o in obj), type_
                )
                for key, name, obj, type_ in self._result_columns
            ]

        text = self._compose_select_body(
            text,
            select_stmt,
            compile_state,
            inner_columns,
            froms,
            byfrom,
            toplevel,
            kwargs,
        )

        if select_stmt._post_body_clause is not None:
            pbc = self.process(select_stmt._post_body_clause, **kwargs)
            if pbc:
                text += " " + pbc

        if select_stmt._statement_hints:
            per_dialect = [
                ht
                for (dialect_name, ht) in select_stmt._statement_hints
                if dialect_name in ("*", self.dialect.name)
            ]
            if per_dialect:
                text += " " + self.get_statement_hint_text(per_dialect)

        # In compound query, CTEs are shared at the compound level
        if self.ctes and (not is_embedded_select or toplevel):
            nesting_level = len(self.stack) if not toplevel else None
            text = self._render_cte_clause(nesting_level=nesting_level) + text

        if select_stmt._suffixes:
            text += " " + self._generate_prefixes(
                select_stmt, select_stmt._suffixes, **kwargs
            )

        # pop the entry pushed by _setup_select_stack()
        self.stack.pop(-1)

        return text
4960
4961 def _setup_select_hints(
4962 self, select: Select[Unpack[TupleAny]]
4963 ) -> Tuple[str, _FromHintsType]:
4964 byfrom = {
4965 from_: hinttext
4966 % {"name": from_._compiler_dispatch(self, ashint=True)}
4967 for (from_, dialect), hinttext in select._hints.items()
4968 if dialect in ("*", self.dialect.name)
4969 }
4970 hint_text = self.get_select_hint_text(byfrom)
4971 return hint_text, byfrom
4972
    def _setup_select_stack(
        self, select, compile_state, entry, asfrom, lateral, compound_index
    ):
        """Compute the display FROM list for a SELECT and push a new
        compiler stack entry for it.

        Also validates that all SELECTs within a CompoundSelect carry the
        same number of columns.

        :return: the list of FROM elements to render.
        :raises CompileError: on a column-count mismatch within a
         CompoundSelect.
        """
        correlate_froms = entry["correlate_froms"]
        asfrom_froms = entry["asfrom_froms"]

        if compound_index == 0:
            # first SELECT of a CompoundSelect; remember it so subsequent
            # members can be validated against it
            entry["select_0"] = select
        elif compound_index:
            select_0 = entry["select_0"]
            numcols = len(select_0._all_selected_columns)

            if len(compile_state.columns_plus_names) != numcols:
                raise exc.CompileError(
                    "All selectables passed to "
                    "CompoundSelect must have identical numbers of "
                    "columns; select #%d has %d columns, select "
                    "#%d has %d"
                    % (
                        1,
                        numcols,
                        compound_index + 1,
                        len(select._all_selected_columns),
                    )
                )

        if asfrom and not lateral:
            froms = compile_state._get_display_froms(
                explicit_correlate_froms=correlate_froms.difference(
                    asfrom_froms
                ),
                implicit_correlate_froms=(),
            )
        else:
            froms = compile_state._get_display_froms(
                explicit_correlate_froms=correlate_froms,
                implicit_correlate_froms=asfrom_froms,
            )

        new_correlate_froms = set(_from_objects(*froms))
        all_correlate_froms = new_correlate_froms.union(correlate_froms)

        # push a new stack entry; visit_select() pops it when rendering
        # of this SELECT is complete
        new_entry: _CompilerStackEntry = {
            "asfrom_froms": new_correlate_froms,
            "correlate_froms": all_correlate_froms,
            "selectable": select,
            "compile_state": compile_state,
        }
        self.stack.append(new_entry)

        return froms
5024
    def _compose_select_body(
        self,
        text,
        select,
        compile_state,
        inner_columns,
        froms,
        byfrom,
        toplevel,
        kwargs,
    ):
        """Assemble the body of a SELECT following the column list:
        FROM, WHERE, GROUP BY, HAVING, ORDER BY, row-limiting and
        FOR UPDATE clauses, appended to and returned as ``text``."""
        text += ", ".join(inner_columns)

        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
            # track FROM elements / joins in order to warn about
            # cartesian products
            from_linter = FromLinter({}, set())
            warn_linting = self.linting & WARN_LINTING
            if toplevel:
                self.from_linter = from_linter
        else:
            from_linter = None
            warn_linting = False

        # adjust the whitespace for no inner columns, part of #9440,
        # so that a no-col SELECT comes out as "SELECT WHERE..." or
        # "SELECT FROM ...".
        # while it would be better to have built the SELECT starting string
        # without trailing whitespace first, then add whitespace only if inner
        # cols were present, this breaks compatibility with various custom
        # compilation schemes that are currently being tested.
        if not inner_columns:
            text = text.rstrip()

        if froms:
            text += " \nFROM "

            if select._hints:
                text += ", ".join(
                    [
                        f._compiler_dispatch(
                            self,
                            asfrom=True,
                            fromhints=byfrom,
                            from_linter=from_linter,
                            **kwargs,
                        )
                        for f in froms
                    ]
                )
            else:
                text += ", ".join(
                    [
                        f._compiler_dispatch(
                            self,
                            asfrom=True,
                            from_linter=from_linter,
                            **kwargs,
                        )
                        for f in froms
                    ]
                )
        else:
            text += self.default_from()

        if select._where_criteria:
            t = self._generate_delimited_and_list(
                select._where_criteria, from_linter=from_linter, **kwargs
            )
            if t:
                text += " \nWHERE " + t

        if warn_linting:
            assert from_linter is not None
            from_linter.warn()

        if select._group_by_clauses:
            text += self.group_by_clause(select, **kwargs)

        if select._having_criteria:
            t = self._generate_delimited_and_list(
                select._having_criteria, **kwargs
            )
            if t:
                text += " \nHAVING " + t

        if select._post_criteria_clause is not None:
            pcc = self.process(select._post_criteria_clause, **kwargs)
            if pcc is not None:
                text += " \n" + pcc

        if select._order_by_clauses:
            text += self.order_by_clause(select, **kwargs)

        if select._has_row_limiting_clause:
            text += self._row_limit_clause(select, **kwargs)

        if select._for_update_arg is not None:
            text += self.for_update_clause(select, **kwargs)

        return text
5124
5125 def _generate_prefixes(self, stmt, prefixes, **kw):
5126 clause = " ".join(
5127 prefix._compiler_dispatch(self, **kw)
5128 for prefix, dialect_name in prefixes
5129 if dialect_name in (None, "*") or dialect_name == self.dialect.name
5130 )
5131 if clause:
5132 clause += " "
5133 return clause
5134
5135 def _render_cte_clause(
5136 self,
5137 nesting_level=None,
5138 include_following_stack=False,
5139 ):
5140 """
5141 include_following_stack
5142 Also render the nesting CTEs on the next stack. Useful for
5143 SQL structures like UNION or INSERT that can wrap SELECT
5144 statements containing nesting CTEs.
5145 """
5146 if not self.ctes:
5147 return ""
5148
5149 ctes: MutableMapping[CTE, str]
5150
5151 if nesting_level and nesting_level > 1:
5152 ctes = util.OrderedDict()
5153 for cte in list(self.ctes.keys()):
5154 cte_level, cte_name, cte_opts = self.level_name_by_cte[
5155 cte._get_reference_cte()
5156 ]
5157 nesting = cte.nesting or cte_opts.nesting
5158 is_rendered_level = cte_level == nesting_level or (
5159 include_following_stack and cte_level == nesting_level + 1
5160 )
5161 if not (nesting and is_rendered_level):
5162 continue
5163
5164 ctes[cte] = self.ctes[cte]
5165
5166 else:
5167 ctes = self.ctes
5168
5169 if not ctes:
5170 return ""
5171 ctes_recursive = any([cte.recursive for cte in ctes])
5172
5173 cte_text = self.get_cte_preamble(ctes_recursive) + " "
5174 cte_text += ", \n".join([txt for txt in ctes.values()])
5175 cte_text += "\n "
5176
5177 if nesting_level and nesting_level > 1:
5178 for cte in list(ctes.keys()):
5179 cte_level, cte_name, cte_opts = self.level_name_by_cte[
5180 cte._get_reference_cte()
5181 ]
5182 del self.ctes[cte]
5183 del self.ctes_by_level_name[(cte_level, cte_name)]
5184 del self.level_name_by_cte[cte._get_reference_cte()]
5185
5186 return cte_text
5187
5188 def get_cte_preamble(self, recursive):
5189 if recursive:
5190 return "WITH RECURSIVE"
5191 else:
5192 return "WITH"
5193
5194 def get_select_precolumns(self, select: Select[Any], **kw: Any) -> str:
5195 """Called when building a ``SELECT`` statement, position is just
5196 before column list.
5197
5198 """
5199 if select._distinct_on:
5200 util.warn_deprecated(
5201 "DISTINCT ON is currently supported only by the PostgreSQL "
5202 "dialect. Use of DISTINCT ON for other backends is currently "
5203 "silently ignored, however this usage is deprecated, and will "
5204 "raise CompileError in a future release for all backends "
5205 "that do not support this syntax.",
5206 version="1.4",
5207 )
5208 return "DISTINCT " if select._distinct else ""
5209
5210 def group_by_clause(self, select, **kw):
5211 """allow dialects to customize how GROUP BY is rendered."""
5212
5213 group_by = self._generate_delimited_list(
5214 select._group_by_clauses, OPERATORS[operators.comma_op], **kw
5215 )
5216 if group_by:
5217 return " GROUP BY " + group_by
5218 else:
5219 return ""
5220
5221 def order_by_clause(self, select, **kw):
5222 """allow dialects to customize how ORDER BY is rendered."""
5223
5224 order_by = self._generate_delimited_list(
5225 select._order_by_clauses, OPERATORS[operators.comma_op], **kw
5226 )
5227
5228 if order_by:
5229 return " ORDER BY " + order_by
5230 else:
5231 return ""
5232
5233 def for_update_clause(self, select, **kw):
5234 return " FOR UPDATE"
5235
5236 def returning_clause(
5237 self,
5238 stmt: UpdateBase,
5239 returning_cols: Sequence[_ColumnsClauseElement],
5240 *,
5241 populate_result_map: bool,
5242 **kw: Any,
5243 ) -> str:
5244 columns = [
5245 self._label_returning_column(
5246 stmt,
5247 column,
5248 populate_result_map,
5249 fallback_label_name=fallback_label_name,
5250 column_is_repeated=repeated,
5251 name=name,
5252 proxy_name=proxy_name,
5253 **kw,
5254 )
5255 for (
5256 name,
5257 proxy_name,
5258 fallback_label_name,
5259 column,
5260 repeated,
5261 ) in stmt._generate_columns_plus_names(
5262 True, cols=base._select_iterables(returning_cols)
5263 )
5264 ]
5265
5266 return "RETURNING " + ", ".join(columns)
5267
5268 def limit_clause(self, select, **kw):
5269 text = ""
5270 if select._limit_clause is not None:
5271 text += "\n LIMIT " + self.process(select._limit_clause, **kw)
5272 if select._offset_clause is not None:
5273 if select._limit_clause is None:
5274 text += "\n LIMIT -1"
5275 text += " OFFSET " + self.process(select._offset_clause, **kw)
5276 return text
5277
5278 def fetch_clause(
5279 self,
5280 select,
5281 fetch_clause=None,
5282 require_offset=False,
5283 use_literal_execute_for_simple_int=False,
5284 **kw,
5285 ):
5286 if fetch_clause is None:
5287 fetch_clause = select._fetch_clause
5288 fetch_clause_options = select._fetch_clause_options
5289 else:
5290 fetch_clause_options = {"percent": False, "with_ties": False}
5291
5292 text = ""
5293
5294 if select._offset_clause is not None:
5295 offset_clause = select._offset_clause
5296 if (
5297 use_literal_execute_for_simple_int
5298 and select._simple_int_clause(offset_clause)
5299 ):
5300 offset_clause = offset_clause.render_literal_execute()
5301 offset_str = self.process(offset_clause, **kw)
5302 text += "\n OFFSET %s ROWS" % offset_str
5303 elif require_offset:
5304 text += "\n OFFSET 0 ROWS"
5305
5306 if fetch_clause is not None:
5307 if (
5308 use_literal_execute_for_simple_int
5309 and select._simple_int_clause(fetch_clause)
5310 ):
5311 fetch_clause = fetch_clause.render_literal_execute()
5312 text += "\n FETCH FIRST %s%s ROWS %s" % (
5313 self.process(fetch_clause, **kw),
5314 " PERCENT" if fetch_clause_options["percent"] else "",
5315 "WITH TIES" if fetch_clause_options["with_ties"] else "ONLY",
5316 )
5317 return text
5318
    def visit_table(
        self,
        table,
        asfrom=False,
        iscrud=False,
        ashint=False,
        fromhints=None,
        use_schema=True,
        from_linter=None,
        ambiguous_table_name_map=None,
        enclosing_alias=None,
        **kwargs,
    ):
        """Render a Table object.

        Produces the (optionally schema-qualified) quoted table name when
        the table is rendered as a FROM element or a hint target
        (``asfrom`` / ``ashint``); in any other context an empty string is
        returned.
        """
        if from_linter:
            # register this table with the cartesian-product linter
            from_linter.froms[table] = table.fullname

        if asfrom or ashint:
            effective_schema = self.preparer.schema_for_object(table)

            if use_schema and effective_schema:
                # render "schema"."tablename"
                ret = (
                    self.preparer.quote_schema(effective_schema)
                    + "."
                    + self.preparer.quote(table.name)
                )
            else:
                ret = self.preparer.quote(table.name)

            # a schema-less table whose name appears in the ambiguous-name
            # map collides with another table name in this statement;
            # disambiguate with an anonymized alias, unless this table is
            # exactly the element wrapped by the enclosing alias
            if (
                (
                    enclosing_alias is None
                    or enclosing_alias.element is not table
                )
                and not effective_schema
                and ambiguous_table_name_map
                and table.name in ambiguous_table_name_map
            ):
                anon_name = self._truncated_identifier(
                    "alias", ambiguous_table_name_map[table.name]
                )

                ret = ret + self.get_render_as_alias_suffix(
                    self.preparer.format_alias(None, anon_name)
                )

            if fromhints and table in fromhints:
                # apply per-table hint text via the dialect hook
                ret = self.format_from_hint_text(
                    ret, table, fromhints[table], iscrud
                )
            return ret
        else:
            return ""
5371
5372 def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
5373 if from_linter:
5374 from_linter.edges.update(
5375 itertools.product(
5376 _de_clone(join.left._from_objects),
5377 _de_clone(join.right._from_objects),
5378 )
5379 )
5380
5381 if join.full:
5382 join_type = " FULL OUTER JOIN "
5383 elif join.isouter:
5384 join_type = " LEFT OUTER JOIN "
5385 else:
5386 join_type = " JOIN "
5387 return (
5388 join.left._compiler_dispatch(
5389 self, asfrom=True, from_linter=from_linter, **kwargs
5390 )
5391 + join_type
5392 + join.right._compiler_dispatch(
5393 self, asfrom=True, from_linter=from_linter, **kwargs
5394 )
5395 + " ON "
5396 # TODO: likely need asfrom=True here?
5397 + join.onclause._compiler_dispatch(
5398 self, from_linter=from_linter, **kwargs
5399 )
5400 )
5401
5402 def _setup_crud_hints(self, stmt, table_text):
5403 dialect_hints = {
5404 table: hint_text
5405 for (table, dialect), hint_text in stmt._hints.items()
5406 if dialect in ("*", self.dialect.name)
5407 }
5408 if stmt.table in dialect_hints:
5409 table_text = self.format_from_hint_text(
5410 table_text, stmt.table, dialect_hints[stmt.table], True
5411 )
5412 return dialect_hints, table_text
5413
    # within the realm of "insertmanyvalues sentinel columns",
    # these lookups match different kinds of Column() configurations
    # to specific backend capabilities. they are broken into two
    # lookups, one for autoincrement columns and the other for non
    # autoincrement columns
    _sentinel_col_non_autoinc_lookup = util.immutabledict(
        {
            # client-side / plain defaults (and "no default") map to the
            # baseline capability flag
            _SentinelDefaultCharacterization.CLIENTSIDE: (
                InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT
            ),
            _SentinelDefaultCharacterization.SENTINEL_DEFAULT: (
                InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT
            ),
            _SentinelDefaultCharacterization.NONE: (
                InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT
            ),
            # server-generated defaults require the corresponding
            # dialect capability bit
            _SentinelDefaultCharacterization.IDENTITY: (
                InsertmanyvaluesSentinelOpts.IDENTITY
            ),
            _SentinelDefaultCharacterization.SEQUENCE: (
                InsertmanyvaluesSentinelOpts.SEQUENCE
            ),
        }
    )
    # same lookup for autoincrement columns; identical except that a
    # column with no default is characterized as AUTOINCREMENT
    _sentinel_col_autoinc_lookup = _sentinel_col_non_autoinc_lookup.union(
        {
            _SentinelDefaultCharacterization.NONE: (
                InsertmanyvaluesSentinelOpts.AUTOINCREMENT
            ),
        }
    )
5445
5446 def _get_sentinel_column_for_table(
5447 self, table: Table
5448 ) -> Optional[Sequence[Column[Any]]]:
5449 """given a :class:`.Table`, return a usable sentinel column or
5450 columns for this dialect if any.
5451
5452 Return None if no sentinel columns could be identified, or raise an
5453 error if a column was marked as a sentinel explicitly but isn't
5454 compatible with this dialect.
5455
5456 """
5457
5458 sentinel_opts = self.dialect.insertmanyvalues_implicit_sentinel
5459 sentinel_characteristics = table._sentinel_column_characteristics
5460
5461 sent_cols = sentinel_characteristics.columns
5462
5463 if sent_cols is None:
5464 return None
5465
5466 if sentinel_characteristics.is_autoinc:
5467 bitmask = self._sentinel_col_autoinc_lookup.get(
5468 sentinel_characteristics.default_characterization, 0
5469 )
5470 else:
5471 bitmask = self._sentinel_col_non_autoinc_lookup.get(
5472 sentinel_characteristics.default_characterization, 0
5473 )
5474
5475 if sentinel_opts & bitmask:
5476 return sent_cols
5477
5478 if sentinel_characteristics.is_explicit:
5479 # a column was explicitly marked as insert_sentinel=True,
5480 # however it is not compatible with this dialect. they should
5481 # not indicate this column as a sentinel if they need to include
5482 # this dialect.
5483
5484 # TODO: do we want non-primary key explicit sentinel cols
5485 # that can gracefully degrade for some backends?
5486 # insert_sentinel="degrade" perhaps. not for the initial release.
5487 # I am hoping people are generally not dealing with this sentinel
5488 # business at all.
5489
5490 # if is_explicit is True, there will be only one sentinel column.
5491
5492 raise exc.InvalidRequestError(
5493 f"Column {sent_cols[0]} can't be explicitly "
5494 "marked as a sentinel column when using the "
5495 f"{self.dialect.name} dialect, as the "
5496 "particular type of default generation on this column is "
5497 "not currently compatible with this dialect's specific "
5498 f"INSERT..RETURNING syntax which can receive the "
5499 "server-generated value in "
5500 "a deterministic way. To remove this error, remove "
5501 "insert_sentinel=True from primary key autoincrement "
5502 "columns; these columns are automatically used as "
5503 "sentinels for supported dialects in any case."
5504 )
5505
5506 return None
5507
    def _deliver_insertmanyvalues_batches(
        self,
        statement: str,
        parameters: _DBAPIMultiExecuteParams,
        compiled_parameters: List[_MutableCoreSingleExecuteParams],
        generic_setinputsizes: Optional[_GenericSetInputSizesType],
        batch_size: int,
        sort_by_parameter_order: bool,
        schema_translate_map: Optional[SchemaTranslateMapType],
    ) -> Iterator[_InsertManyValuesBatch]:
        """Rewrite a compiled "insertmanyvalues" INSERT into executable
        batches, yielding one :class:`._InsertManyValuesBatch` per batch.

        Depending on dialect capabilities, either each parameter set is
        yielded as its own single-row statement ("row at a time"), or the
        single-row VALUES clause in ``statement`` is expanded into a
        multi-row VALUES clause covering up to ``batch_size`` parameter
        sets, with the batch size further shrunk when the dialect declares
        ``insertmanyvalues_max_parameters``.
        """
        imv = self._insertmanyvalues
        assert imv is not None

        # itemgetter that extracts client-side sentinel values from a
        # compiled parameter dict; None when no named sentinel params exist
        if not imv.sentinel_param_keys:
            _sentinel_from_params = None
        else:
            _sentinel_from_params = operator.itemgetter(
                *imv.sentinel_param_keys
            )

        lenparams = len(parameters)
        if imv.is_default_expr and not self.dialect.supports_default_metavalue:
            # backend doesn't support
            # INSERT INTO table (pk_col) VALUES (DEFAULT), (DEFAULT), ...
            # at the moment this is basically SQL Server due to
            # not being able to use DEFAULT for identity column
            # just yield out that many single statements! still
            # faster than a whole connection.execute() call ;)
            #
            # note we still are taking advantage of the fact that we know
            # we are using RETURNING. The generalized approach of fetching
            # cursor.lastrowid etc. still goes through the more heavyweight
            # "ExecutionContext per statement" system as it isn't usable
            # as a generic "RETURNING" approach
            use_row_at_a_time = True
            downgraded = False
        elif not self.dialect.supports_multivalues_insert or (
            sort_by_parameter_order
            and self._result_columns
            and (imv.sentinel_columns is None or imv.includes_upsert_behaviors)
        ):
            # deterministic order was requested and the compiler could
            # not organize sentinel columns for this dialect/statement.
            # use row at a time
            use_row_at_a_time = True
            downgraded = True
        else:
            use_row_at_a_time = False
            downgraded = False

        if use_row_at_a_time:
            # yield each parameter set as its own single-row batch
            for batchnum, (param, compiled_param) in enumerate(
                cast(
                    "Sequence[Tuple[_DBAPISingleExecuteParams, _MutableCoreSingleExecuteParams]]",  # noqa: E501
                    zip(parameters, compiled_parameters),
                ),
                1,
            ):
                yield _InsertManyValuesBatch(
                    statement,
                    param,
                    generic_setinputsizes,
                    [param],
                    (
                        [_sentinel_from_params(compiled_param)]
                        if _sentinel_from_params
                        else []
                    ),
                    1,
                    batchnum,
                    lenparams,
                    sort_by_parameter_order,
                    downgraded,
                )
            return

        if schema_translate_map:
            rst = functools.partial(
                self.preparer._render_schema_translates,
                schema_translate_map=schema_translate_map,
            )
        else:
            rst = None

        imv_single_values_expr = imv.single_values_expr
        if rst:
            imv_single_values_expr = rst(imv_single_values_expr)

        # replace the single-row "(...)" VALUES expression with a token
        # that gets swapped for the expanded multi-row clause per batch
        executemany_values = f"({imv_single_values_expr})"
        statement = statement.replace(executemany_values, "__EXECMANY_TOKEN__")

        # Use optional insertmanyvalues_max_parameters
        # to further shrink the batch size so that there are no more than
        # insertmanyvalues_max_parameters params.
        # Currently used by SQL Server, which limits statements to 2100 bound
        # parameters (actually 2099).
        max_params = self.dialect.insertmanyvalues_max_parameters
        if max_params:
            total_num_of_params = len(self.bind_names)
            num_params_per_batch = len(imv.insert_crud_params)
            num_params_outside_of_batch = (
                total_num_of_params - num_params_per_batch
            )
            batch_size = min(
                batch_size,
                (
                    (max_params - num_params_outside_of_batch)
                    // num_params_per_batch
                ),
            )

        batches = cast("List[Sequence[Any]]", list(parameters))
        compiled_batches = cast(
            "List[Sequence[Any]]", list(compiled_parameters)
        )

        processed_setinputsizes: Optional[_GenericSetInputSizesType] = None
        batchnum = 1
        # ceiling division: a partial final batch counts as one more batch
        total_batches = lenparams // batch_size + (
            1 if lenparams % batch_size else 0
        )

        insert_crud_params = imv.insert_crud_params
        assert insert_crud_params is not None

        if rst:
            # apply schema translation to per-column expressions as well
            insert_crud_params = [
                (col, key, rst(expr), st)
                for col, key, expr, st in insert_crud_params
            ]

        escaped_bind_names: Mapping[str, str]
        expand_pos_lower_index = expand_pos_upper_index = 0

        if not self.positional:
            # named-parameter dialects: rewrite each bound name with a
            # "__EXECMANY_INDEX__" marker replaced per row later
            if self.escaped_bind_names:
                escaped_bind_names = self.escaped_bind_names
            else:
                escaped_bind_names = {}

            all_keys = set(parameters[0])

            def apply_placeholders(keys, formatted):
                # substitute each ":name" placeholder with an indexed form
                for key in keys:
                    key = escaped_bind_names.get(key, key)
                    formatted = formatted.replace(
                        self.bindtemplate % {"name": key},
                        self.bindtemplate
                        % {"name": f"{key}__EXECMANY_INDEX__"},
                    )
                return formatted

            if imv.embed_values_counter:
                imv_values_counter = ", _IMV_VALUES_COUNTER"
            else:
                imv_values_counter = ""
            formatted_values_clause = f"""({', '.join(
                apply_placeholders(bind_keys, formatted)
                for _, _, formatted, bind_keys in insert_crud_params
            )}{imv_values_counter})"""

            # parameter keys that belong to the VALUES clause and will be
            # renamed per-row; all other keys are passed through unchanged
            keys_to_replace = all_keys.intersection(
                escaped_bind_names.get(key, key)
                for _, _, _, bind_keys in insert_crud_params
                for key in bind_keys
            )
            base_parameters = {
                key: parameters[0][key]
                for key in all_keys.difference(keys_to_replace)
            }
            executemany_values_w_comma = ""
        else:
            # positional dialects: repeat the single-row tuple expression
            formatted_values_clause = ""
            keys_to_replace = set()
            base_parameters = {}

            if imv.embed_values_counter:
                executemany_values_w_comma = (
                    f"({imv_single_values_expr}, _IMV_VALUES_COUNTER), "
                )
            else:
                executemany_values_w_comma = f"({imv_single_values_expr}), "

            all_names_we_will_expand: Set[str] = set()
            for elem in imv.insert_crud_params:
                all_names_we_will_expand.update(elem[3])

            # get the start and end position in a particular list
            # of parameters where we will be doing the "expanding".
            # statements can have params on either side or both sides,
            # given RETURNING and CTEs
            if all_names_we_will_expand:
                positiontup = self.positiontup
                assert positiontup is not None

                all_expand_positions = {
                    idx
                    for idx, name in enumerate(positiontup)
                    if name in all_names_we_will_expand
                }
                expand_pos_lower_index = min(all_expand_positions)
                expand_pos_upper_index = max(all_expand_positions) + 1
                # the expanded positions must be contiguous
                assert (
                    len(all_expand_positions)
                    == expand_pos_upper_index - expand_pos_lower_index
                )

            if self._numeric_binds:
                # normalize numbered placeholders (e.g. ":1") to "%s" so
                # they can be renumbered per batch below
                escaped = re.escape(self._numeric_binds_identifier_char)
                executemany_values_w_comma = re.sub(
                    rf"{escaped}\d+", "%s", executemany_values_w_comma
                )

        # main batching loop: consume up to batch_size parameter sets
        # per iteration and yield an expanded statement for each
        while batches:
            batch = batches[0:batch_size]
            compiled_batch = compiled_batches[0:batch_size]

            batches[0:batch_size] = []
            compiled_batches[0:batch_size] = []

            if batches:
                current_batch_size = batch_size
            else:
                current_batch_size = len(batch)

            if generic_setinputsizes:
                # if setinputsizes is present, expand this collection to
                # suit the batch length as well
                # currently this will be mssql+pyodbc for internal dialects
                processed_setinputsizes = [
                    (new_key, len_, typ)
                    for new_key, len_, typ in (
                        (f"{key}_{index}", len_, typ)
                        for index in range(current_batch_size)
                        for key, len_, typ in generic_setinputsizes
                    )
                ]

            replaced_parameters: Any
            if self.positional:
                num_ins_params = imv.num_positional_params_counted

                batch_iterator: Iterable[Sequence[Any]]
                extra_params_left: Sequence[Any]
                extra_params_right: Sequence[Any]

                if num_ins_params == len(batch[0]):
                    # all params are VALUES params; no slicing needed
                    extra_params_left = extra_params_right = ()
                    batch_iterator = batch
                else:
                    # params outside the VALUES region come from the first
                    # row only; the VALUES region is repeated per row
                    extra_params_left = batch[0][:expand_pos_lower_index]
                    extra_params_right = batch[0][expand_pos_upper_index:]
                    batch_iterator = (
                        b[expand_pos_lower_index:expand_pos_upper_index]
                        for b in batch
                    )

                if imv.embed_values_counter:
                    expanded_values_string = (
                        "".join(
                            executemany_values_w_comma.replace(
                                "_IMV_VALUES_COUNTER", str(i)
                            )
                            for i, _ in enumerate(batch)
                        )
                    )[:-2]
                else:
                    # [:-2] strips the trailing ", "
                    expanded_values_string = (
                        (executemany_values_w_comma * current_batch_size)
                    )[:-2]

                if self._numeric_binds and num_ins_params > 0:
                    # numeric will always number the parameters inside of
                    # VALUES (and thus order self.positiontup) to be higher
                    # than non-VALUES parameters, no matter where in the
                    # statement those non-VALUES parameters appear (this is
                    # ensured in _process_numeric by numbering first all
                    # params that are not in _values_bindparam)
                    # therefore all extra params are always
                    # on the left side and numbered lower than the VALUES
                    # parameters
                    assert not extra_params_right

                    start = expand_pos_lower_index + 1
                    end = num_ins_params * (current_batch_size) + start

                    # need to format here, since statement may contain
                    # unescaped %, while values_string contains just (%s, %s)
                    positions = tuple(
                        f"{self._numeric_binds_identifier_char}{i}"
                        for i in range(start, end)
                    )
                    expanded_values_string = expanded_values_string % positions

                replaced_statement = statement.replace(
                    "__EXECMANY_TOKEN__", expanded_values_string
                )

                replaced_parameters = tuple(
                    itertools.chain.from_iterable(batch_iterator)
                )

                replaced_parameters = (
                    extra_params_left
                    + replaced_parameters
                    + extra_params_right
                )

            else:
                # named parameters: index each per-row key as "key__N"
                replaced_values_clauses = []
                replaced_parameters = base_parameters.copy()

                for i, param in enumerate(batch):
                    fmv = formatted_values_clause.replace(
                        "EXECMANY_INDEX__", str(i)
                    )
                    if imv.embed_values_counter:
                        fmv = fmv.replace("_IMV_VALUES_COUNTER", str(i))

                    replaced_values_clauses.append(fmv)
                    replaced_parameters.update(
                        {f"{key}__{i}": param[key] for key in keys_to_replace}
                    )

                replaced_statement = statement.replace(
                    "__EXECMANY_TOKEN__",
                    ", ".join(replaced_values_clauses),
                )

            yield _InsertManyValuesBatch(
                replaced_statement,
                replaced_parameters,
                processed_setinputsizes,
                batch,
                (
                    [_sentinel_from_params(cb) for cb in compiled_batch]
                    if _sentinel_from_params
                    else []
                ),
                current_batch_size,
                batchnum,
                total_batches,
                sort_by_parameter_order,
                False,
            )
            batchnum += 1
5854
5855 def visit_insert(
5856 self, insert_stmt, visited_bindparam=None, visiting_cte=None, **kw
5857 ):
5858 compile_state = insert_stmt._compile_state_factory(
5859 insert_stmt, self, **kw
5860 )
5861 insert_stmt = compile_state.statement
5862
5863 if visiting_cte is not None:
5864 kw["visiting_cte"] = visiting_cte
5865 toplevel = False
5866 else:
5867 toplevel = not self.stack
5868
5869 if toplevel:
5870 self.isinsert = True
5871 if not self.dml_compile_state:
5872 self.dml_compile_state = compile_state
5873 if not self.compile_state:
5874 self.compile_state = compile_state
5875
5876 self.stack.append(
5877 {
5878 "correlate_froms": set(),
5879 "asfrom_froms": set(),
5880 "selectable": insert_stmt,
5881 }
5882 )
5883
5884 counted_bindparam = 0
5885
5886 # reset any incoming "visited_bindparam" collection
5887 visited_bindparam = None
5888
5889 # for positional, insertmanyvalues needs to know how many
5890 # bound parameters are in the VALUES sequence; there's no simple
5891 # rule because default expressions etc. can have zero or more
5892 # params inside them. After multiple attempts to figure this out,
5893 # this very simplistic "count after" works and is
5894 # likely the least amount of callcounts, though looks clumsy
5895 if self.positional and visiting_cte is None:
5896 # if we are inside a CTE, don't count parameters
5897 # here since they wont be for insertmanyvalues. keep
5898 # visited_bindparam at None so no counting happens.
5899 # see #9173
5900 visited_bindparam = []
5901
5902 crud_params_struct = crud._get_crud_params(
5903 self,
5904 insert_stmt,
5905 compile_state,
5906 toplevel,
5907 visited_bindparam=visited_bindparam,
5908 **kw,
5909 )
5910
5911 if self.positional and visited_bindparam is not None:
5912 counted_bindparam = len(visited_bindparam)
5913 if self._numeric_binds:
5914 if self._values_bindparam is not None:
5915 self._values_bindparam += visited_bindparam
5916 else:
5917 self._values_bindparam = visited_bindparam
5918
5919 crud_params_single = crud_params_struct.single_params
5920
5921 if (
5922 not crud_params_single
5923 and not self.dialect.supports_default_values
5924 and not self.dialect.supports_default_metavalue
5925 and not self.dialect.supports_empty_insert
5926 ):
5927 raise exc.CompileError(
5928 "The '%s' dialect with current database "
5929 "version settings does not support empty "
5930 "inserts." % self.dialect.name
5931 )
5932
5933 if compile_state._has_multi_parameters:
5934 if not self.dialect.supports_multivalues_insert:
5935 raise exc.CompileError(
5936 "The '%s' dialect with current database "
5937 "version settings does not support "
5938 "in-place multirow inserts." % self.dialect.name
5939 )
5940 elif (
5941 self.implicit_returning or insert_stmt._returning
5942 ) and insert_stmt._sort_by_parameter_order:
5943 raise exc.CompileError(
5944 "RETURNING cannot be determinstically sorted when "
5945 "using an INSERT which includes multi-row values()."
5946 )
5947 crud_params_single = crud_params_struct.single_params
5948 else:
5949 crud_params_single = crud_params_struct.single_params
5950
5951 preparer = self.preparer
5952 supports_default_values = self.dialect.supports_default_values
5953
5954 text = "INSERT "
5955
5956 if insert_stmt._prefixes:
5957 text += self._generate_prefixes(
5958 insert_stmt, insert_stmt._prefixes, **kw
5959 )
5960
5961 text += "INTO "
5962 table_text = preparer.format_table(insert_stmt.table)
5963
5964 if insert_stmt._hints:
5965 _, table_text = self._setup_crud_hints(insert_stmt, table_text)
5966
5967 if insert_stmt._independent_ctes:
5968 self._dispatch_independent_ctes(insert_stmt, kw)
5969
5970 text += table_text
5971
5972 if crud_params_single or not supports_default_values:
5973 text += " (%s)" % ", ".join(
5974 [expr for _, expr, _, _ in crud_params_single]
5975 )
5976
5977 # look for insertmanyvalues attributes that would have been configured
5978 # by crud.py as it scanned through the columns to be part of the
5979 # INSERT
5980 use_insertmanyvalues = crud_params_struct.use_insertmanyvalues
5981 named_sentinel_params: Optional[Sequence[str]] = None
5982 add_sentinel_cols = None
5983 implicit_sentinel = False
5984
5985 returning_cols = self.implicit_returning or insert_stmt._returning
5986 if returning_cols:
5987 add_sentinel_cols = crud_params_struct.use_sentinel_columns
5988 if add_sentinel_cols is not None:
5989 assert use_insertmanyvalues
5990
5991 # search for the sentinel column explicitly present
5992 # in the INSERT columns list, and additionally check that
5993 # this column has a bound parameter name set up that's in the
5994 # parameter list. If both of these cases are present, it means
5995 # we will have a client side value for the sentinel in each
5996 # parameter set.
5997
5998 _params_by_col = {
5999 col: param_names
6000 for col, _, _, param_names in crud_params_single
6001 }
6002 named_sentinel_params = []
6003 for _add_sentinel_col in add_sentinel_cols:
6004 if _add_sentinel_col not in _params_by_col:
6005 named_sentinel_params = None
6006 break
6007 param_name = self._within_exec_param_key_getter(
6008 _add_sentinel_col
6009 )
6010 if param_name not in _params_by_col[_add_sentinel_col]:
6011 named_sentinel_params = None
6012 break
6013 named_sentinel_params.append(param_name)
6014
6015 if named_sentinel_params is None:
6016 # if we are not going to have a client side value for
6017 # the sentinel in the parameter set, that means it's
6018 # an autoincrement, an IDENTITY, or a server-side SQL
6019 # expression like nextval('seqname'). So this is
6020 # an "implicit" sentinel; we will look for it in
6021 # RETURNING
6022 # only, and then sort on it. For this case on PG,
6023 # SQL Server we have to use a special INSERT form
6024 # that guarantees the server side function lines up with
6025 # the entries in the VALUES.
6026 if (
6027 self.dialect.insertmanyvalues_implicit_sentinel
6028 & InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT
6029 ):
6030 implicit_sentinel = True
6031 else:
6032 # here, we are not using a sentinel at all
6033 # and we are likely the SQLite dialect.
6034 # The first add_sentinel_col that we have should not
6035 # be marked as "insert_sentinel=True". if it was,
6036 # an error should have been raised in
6037 # _get_sentinel_column_for_table.
6038 assert not add_sentinel_cols[0]._insert_sentinel, (
6039 "sentinel selection rules should have prevented "
6040 "us from getting here for this dialect"
6041 )
6042
6043 # always put the sentinel columns last. even if they are
6044 # in the returning list already, they will be there twice
6045 # then.
6046 returning_cols = list(returning_cols) + list(add_sentinel_cols)
6047
6048 returning_clause = self.returning_clause(
6049 insert_stmt,
6050 returning_cols,
6051 populate_result_map=toplevel,
6052 )
6053
6054 if self.returning_precedes_values:
6055 text += " " + returning_clause
6056
6057 else:
6058 returning_clause = None
6059
6060 if insert_stmt.select is not None:
6061 # placed here by crud.py
6062 select_text = self.process(
6063 self.stack[-1]["insert_from_select"], insert_into=True, **kw
6064 )
6065
6066 if self.ctes and self.dialect.cte_follows_insert:
6067 nesting_level = len(self.stack) if not toplevel else None
6068 text += " %s%s" % (
6069 self._render_cte_clause(
6070 nesting_level=nesting_level,
6071 include_following_stack=True,
6072 ),
6073 select_text,
6074 )
6075 else:
6076 text += " %s" % select_text
6077 elif not crud_params_single and supports_default_values:
6078 text += " DEFAULT VALUES"
6079 if use_insertmanyvalues:
6080 self._insertmanyvalues = _InsertManyValues(
6081 True,
6082 self.dialect.default_metavalue_token,
6083 cast(
6084 "List[crud._CrudParamElementStr]", crud_params_single
6085 ),
6086 counted_bindparam,
6087 sort_by_parameter_order=(
6088 insert_stmt._sort_by_parameter_order
6089 ),
6090 includes_upsert_behaviors=(
6091 insert_stmt._post_values_clause is not None
6092 ),
6093 sentinel_columns=add_sentinel_cols,
6094 num_sentinel_columns=(
6095 len(add_sentinel_cols) if add_sentinel_cols else 0
6096 ),
6097 implicit_sentinel=implicit_sentinel,
6098 )
6099 elif compile_state._has_multi_parameters:
6100 text += " VALUES %s" % (
6101 ", ".join(
6102 "(%s)"
6103 % (", ".join(value for _, _, value, _ in crud_param_set))
6104 for crud_param_set in crud_params_struct.all_multi_params
6105 ),
6106 )
6107 else:
6108 insert_single_values_expr = ", ".join(
6109 [
6110 value
6111 for _, _, value, _ in cast(
6112 "List[crud._CrudParamElementStr]",
6113 crud_params_single,
6114 )
6115 ]
6116 )
6117
6118 if use_insertmanyvalues:
6119 if (
6120 implicit_sentinel
6121 and (
6122 self.dialect.insertmanyvalues_implicit_sentinel
6123 & InsertmanyvaluesSentinelOpts.USE_INSERT_FROM_SELECT
6124 )
6125 # this is checking if we have
6126 # INSERT INTO table (id) VALUES (DEFAULT).
6127 and not (crud_params_struct.is_default_metavalue_only)
6128 ):
6129 # if we have a sentinel column that is server generated,
6130 # then for selected backends render the VALUES list as a
6131 # subquery. This is the orderable form supported by
6132 # PostgreSQL and SQL Server.
6133 embed_sentinel_value = True
6134
6135 render_bind_casts = (
6136 self.dialect.insertmanyvalues_implicit_sentinel
6137 & InsertmanyvaluesSentinelOpts.RENDER_SELECT_COL_CASTS
6138 )
6139
6140 colnames = ", ".join(
6141 f"p{i}" for i, _ in enumerate(crud_params_single)
6142 )
6143
6144 if render_bind_casts:
6145 # render casts for the SELECT list. For PG, we are
6146 # already rendering bind casts in the parameter list,
6147 # selectively for the more "tricky" types like ARRAY.
6148 # however, even for the "easy" types, if the parameter
6149 # is NULL for every entry, PG gives up and says
6150 # "it must be TEXT", which fails for other easy types
6151 # like ints. So we cast on this side too.
6152 colnames_w_cast = ", ".join(
6153 self.render_bind_cast(
6154 col.type,
6155 col.type._unwrapped_dialect_impl(self.dialect),
6156 f"p{i}",
6157 )
6158 for i, (col, *_) in enumerate(crud_params_single)
6159 )
6160 else:
6161 colnames_w_cast = colnames
6162
6163 text += (
6164 f" SELECT {colnames_w_cast} FROM "
6165 f"(VALUES ({insert_single_values_expr})) "
6166 f"AS imp_sen({colnames}, sen_counter) "
6167 "ORDER BY sen_counter"
6168 )
6169 else:
6170 # otherwise, if no sentinel or backend doesn't support
6171 # orderable subquery form, use a plain VALUES list
6172 embed_sentinel_value = False
6173 text += f" VALUES ({insert_single_values_expr})"
6174
6175 self._insertmanyvalues = _InsertManyValues(
6176 is_default_expr=False,
6177 single_values_expr=insert_single_values_expr,
6178 insert_crud_params=cast(
6179 "List[crud._CrudParamElementStr]",
6180 crud_params_single,
6181 ),
6182 num_positional_params_counted=counted_bindparam,
6183 sort_by_parameter_order=(
6184 insert_stmt._sort_by_parameter_order
6185 ),
6186 includes_upsert_behaviors=(
6187 insert_stmt._post_values_clause is not None
6188 ),
6189 sentinel_columns=add_sentinel_cols,
6190 num_sentinel_columns=(
6191 len(add_sentinel_cols) if add_sentinel_cols else 0
6192 ),
6193 sentinel_param_keys=named_sentinel_params,
6194 implicit_sentinel=implicit_sentinel,
6195 embed_values_counter=embed_sentinel_value,
6196 )
6197
6198 else:
6199 text += f" VALUES ({insert_single_values_expr})"
6200
6201 if insert_stmt._post_values_clause is not None:
6202 post_values_clause = self.process(
6203 insert_stmt._post_values_clause, **kw
6204 )
6205 if post_values_clause:
6206 text += " " + post_values_clause
6207
6208 if returning_clause and not self.returning_precedes_values:
6209 text += " " + returning_clause
6210
6211 if self.ctes and not self.dialect.cte_follows_insert:
6212 nesting_level = len(self.stack) if not toplevel else None
6213 text = (
6214 self._render_cte_clause(
6215 nesting_level=nesting_level,
6216 include_following_stack=True,
6217 )
6218 + text
6219 )
6220
6221 self.stack.pop(-1)
6222
6223 return text
6224
6225 def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
6226 """Provide a hook to override the initial table clause
6227 in an UPDATE statement.
6228
6229 MySQL overrides this.
6230
6231 """
6232 kw["asfrom"] = True
6233 return from_table._compiler_dispatch(self, iscrud=True, **kw)
6234
6235 def update_from_clause(
6236 self, update_stmt, from_table, extra_froms, from_hints, **kw
6237 ):
6238 """Provide a hook to override the generation of an
6239 UPDATE..FROM clause.
6240 MySQL and MSSQL override this.
6241 """
6242 raise NotImplementedError(
6243 "This backend does not support multiple-table "
6244 "criteria within UPDATE"
6245 )
6246
6247 def update_post_criteria_clause(
6248 self, update_stmt: Update, **kw: Any
6249 ) -> Optional[str]:
6250 """provide a hook to override generation after the WHERE criteria
6251 in an UPDATE statement
6252
6253 .. versionadded:: 2.1
6254
6255 """
6256 if update_stmt._post_criteria_clause is not None:
6257 return self.process(
6258 update_stmt._post_criteria_clause,
6259 **kw,
6260 )
6261 else:
6262 return None
6263
6264 def delete_post_criteria_clause(
6265 self, delete_stmt: Delete, **kw: Any
6266 ) -> Optional[str]:
6267 """provide a hook to override generation after the WHERE criteria
6268 in a DELETE statement
6269
6270 .. versionadded:: 2.1
6271
6272 """
6273 if delete_stmt._post_criteria_clause is not None:
6274 return self.process(
6275 delete_stmt._post_criteria_clause,
6276 **kw,
6277 )
6278 else:
6279 return None
6280
    def visit_update(
        self,
        update_stmt: Update,
        visiting_cte: Optional[CTE] = None,
        **kw: Any,
    ) -> str:
        """Render an UPDATE statement as a SQL string.

        Builds the statement in clause order: prefixes, table clause,
        SET, optional RETURNING (placement depends on
        ``self.returning_precedes_values``), FROM for multi-table
        updates, WHERE, dialect post-criteria clause, and any CTE
        prelude.

        :param update_stmt: the :class:`.Update` construct
        :param visiting_cte: set when this statement is being rendered
         as the body of a CTE
        :return: the SQL string

        """
        # the compile state factory may substitute a new statement
        # object (compile_state.statement); use that from here on
        compile_state = update_stmt._compile_state_factory(
            update_stmt, self, **kw
        )
        if TYPE_CHECKING:
            assert isinstance(compile_state, UpdateDMLState)
        update_stmt = compile_state.statement  # type: ignore[assignment]

        # inside a CTE we are never "toplevel"; otherwise toplevel means
        # the compiler stack is empty (outermost statement)
        if visiting_cte is not None:
            kw["visiting_cte"] = visiting_cte
            toplevel = False
        else:
            toplevel = not self.stack

        if toplevel:
            self.isupdate = True
            if not self.dml_compile_state:
                self.dml_compile_state = compile_state
            if not self.compile_state:
                self.compile_state = compile_state

        # cartesian-product linting is only collected when enabled via
        # self.linting flags
        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
            from_linter = FromLinter({}, set())
            warn_linting = self.linting & WARN_LINTING
            if toplevel:
                self.from_linter = from_linter
        else:
            from_linter = None
            warn_linting = False

        extra_froms = compile_state._extra_froms
        is_multitable = bool(extra_froms)

        if is_multitable:
            # main table might be a JOIN
            main_froms = set(_from_objects(update_stmt.table))
            render_extra_froms = [
                f for f in extra_froms if f not in main_froms
            ]
            correlate_froms = main_froms.union(extra_froms)
        else:
            render_extra_froms = []
            correlate_froms = {update_stmt.table}

        # push a new frame so embedded selects correlate correctly
        self.stack.append(
            {
                "correlate_froms": correlate_froms,
                "asfrom_froms": correlate_froms,
                "selectable": update_stmt,
            }
        )

        text = "UPDATE "

        if update_stmt._prefixes:
            text += self._generate_prefixes(
                update_stmt, update_stmt._prefixes, **kw
            )

        table_text = self.update_tables_clause(
            update_stmt,
            update_stmt.table,
            render_extra_froms,
            from_linter=from_linter,
            **kw,
        )
        crud_params_struct = crud._get_crud_params(
            self, update_stmt, compile_state, toplevel, **kw
        )
        crud_params = crud_params_struct.single_params

        if update_stmt._hints:
            dialect_hints, table_text = self._setup_crud_hints(
                update_stmt, table_text
            )
        else:
            dialect_hints = None

        if update_stmt._independent_ctes:
            self._dispatch_independent_ctes(update_stmt, kw)

        text += table_text

        # SET col=value pairs come from the crud params structure
        text += " SET "
        text += ", ".join(
            expr + "=" + value
            for _, expr, value, _ in cast(
                "List[Tuple[Any, str, str, Any]]", crud_params
            )
        )

        # some dialects render RETURNING before the FROM/WHERE clauses
        if self.implicit_returning or update_stmt._returning:
            if self.returning_precedes_values:
                text += " " + self.returning_clause(
                    update_stmt,
                    self.implicit_returning or update_stmt._returning,
                    populate_result_map=toplevel,
                )

        if extra_froms:
            extra_from_text = self.update_from_clause(
                update_stmt,
                update_stmt.table,
                render_extra_froms,
                dialect_hints,
                from_linter=from_linter,
                **kw,
            )
            if extra_from_text:
                text += " " + extra_from_text

        if update_stmt._where_criteria:
            t = self._generate_delimited_and_list(
                update_stmt._where_criteria, from_linter=from_linter, **kw
            )
            if t:
                text += " WHERE " + t

        # dialect hook for clauses after WHERE (e.g. LIMIT on MySQL)
        ulc = self.update_post_criteria_clause(
            update_stmt, from_linter=from_linter, **kw
        )
        if ulc:
            text += " " + ulc

        if (
            self.implicit_returning or update_stmt._returning
        ) and not self.returning_precedes_values:
            text += " " + self.returning_clause(
                update_stmt,
                self.implicit_returning or update_stmt._returning,
                populate_result_map=toplevel,
            )

        # prepend any accumulated WITH clause
        if self.ctes:
            nesting_level = len(self.stack) if not toplevel else None
            text = self._render_cte_clause(nesting_level=nesting_level) + text

        if warn_linting:
            assert from_linter is not None
            from_linter.warn(stmt_type="UPDATE")

        self.stack.pop(-1)

        return text  # type: ignore[no-any-return]
6430
6431 def delete_extra_from_clause(
6432 self, delete_stmt, from_table, extra_froms, from_hints, **kw
6433 ):
6434 """Provide a hook to override the generation of an
6435 DELETE..FROM clause.
6436
6437 This can be used to implement DELETE..USING for example.
6438
6439 MySQL and MSSQL override this.
6440
6441 """
6442 raise NotImplementedError(
6443 "This backend does not support multiple-table "
6444 "criteria within DELETE"
6445 )
6446
6447 def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
6448 return from_table._compiler_dispatch(
6449 self, asfrom=True, iscrud=True, **kw
6450 )
6451
    def visit_delete(self, delete_stmt, visiting_cte=None, **kw):
        """Render a DELETE statement as a SQL string.

        Builds the statement in clause order: prefixes, FROM, optional
        RETURNING (placement depends on
        ``self.returning_precedes_values``), extra FROM elements for
        multi-table deletes, WHERE, dialect post-criteria clause, and
        any CTE prelude.

        """
        # the compile state factory may substitute a new statement
        # object (compile_state.statement); use that from here on
        compile_state = delete_stmt._compile_state_factory(
            delete_stmt, self, **kw
        )
        delete_stmt = compile_state.statement

        # inside a CTE we are never "toplevel"; otherwise toplevel means
        # the compiler stack is empty (outermost statement)
        if visiting_cte is not None:
            kw["visiting_cte"] = visiting_cte
            toplevel = False
        else:
            toplevel = not self.stack

        if toplevel:
            self.isdelete = True
            if not self.dml_compile_state:
                self.dml_compile_state = compile_state
            if not self.compile_state:
                self.compile_state = compile_state

        # cartesian-product linting is only collected when enabled via
        # self.linting flags
        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
            from_linter = FromLinter({}, set())
            warn_linting = self.linting & WARN_LINTING
            if toplevel:
                self.from_linter = from_linter
        else:
            from_linter = None
            warn_linting = False

        extra_froms = compile_state._extra_froms

        # push a new frame so embedded selects correlate correctly
        correlate_froms = {delete_stmt.table}.union(extra_froms)
        self.stack.append(
            {
                "correlate_froms": correlate_froms,
                "asfrom_froms": correlate_froms,
                "selectable": delete_stmt,
            }
        )

        text = "DELETE "

        if delete_stmt._prefixes:
            text += self._generate_prefixes(
                delete_stmt, delete_stmt._prefixes, **kw
            )

        text += "FROM "

        try:
            table_text = self.delete_table_clause(
                delete_stmt,
                delete_stmt.table,
                extra_froms,
                from_linter=from_linter,
            )
        except TypeError:
            # anticipate 3rd party dialects that don't include **kw
            # TODO: remove in 2.1
            table_text = self.delete_table_clause(
                delete_stmt, delete_stmt.table, extra_froms
            )
            # the fallback call could not pass from_linter through, so
            # feed the table to the linter separately
            if from_linter:
                _ = self.process(delete_stmt.table, from_linter=from_linter)

        crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw)

        if delete_stmt._hints:
            dialect_hints, table_text = self._setup_crud_hints(
                delete_stmt, table_text
            )
        else:
            dialect_hints = None

        if delete_stmt._independent_ctes:
            self._dispatch_independent_ctes(delete_stmt, kw)

        text += table_text

        # some dialects render RETURNING before the WHERE clause
        if (
            self.implicit_returning or delete_stmt._returning
        ) and self.returning_precedes_values:
            text += " " + self.returning_clause(
                delete_stmt,
                self.implicit_returning or delete_stmt._returning,
                populate_result_map=toplevel,
            )

        if extra_froms:
            extra_from_text = self.delete_extra_from_clause(
                delete_stmt,
                delete_stmt.table,
                extra_froms,
                dialect_hints,
                from_linter=from_linter,
                **kw,
            )
            if extra_from_text:
                text += " " + extra_from_text

        if delete_stmt._where_criteria:
            t = self._generate_delimited_and_list(
                delete_stmt._where_criteria, from_linter=from_linter, **kw
            )
            if t:
                text += " WHERE " + t

        # dialect hook for clauses after WHERE (e.g. LIMIT on MySQL)
        dlc = self.delete_post_criteria_clause(
            delete_stmt, from_linter=from_linter, **kw
        )
        if dlc:
            text += " " + dlc

        if (
            self.implicit_returning or delete_stmt._returning
        ) and not self.returning_precedes_values:
            text += " " + self.returning_clause(
                delete_stmt,
                self.implicit_returning or delete_stmt._returning,
                populate_result_map=toplevel,
            )

        # prepend any accumulated WITH clause
        if self.ctes:
            nesting_level = len(self.stack) if not toplevel else None
            text = self._render_cte_clause(nesting_level=nesting_level) + text

        if warn_linting:
            assert from_linter is not None
            from_linter.warn(stmt_type="DELETE")

        self.stack.pop(-1)

        return text
6584
6585 def visit_savepoint(self, savepoint_stmt, **kw):
6586 return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
6587
6588 def visit_rollback_to_savepoint(self, savepoint_stmt, **kw):
6589 return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(
6590 savepoint_stmt
6591 )
6592
6593 def visit_release_savepoint(self, savepoint_stmt, **kw):
6594 return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(
6595 savepoint_stmt
6596 )
6597
6598
class StrSQLCompiler(SQLCompiler):
    """A :class:`.SQLCompiler` subclass which allows a small selection
    of non-standard SQL features to render into a string value.

    The :class:`.StrSQLCompiler` is invoked whenever a Core expression
    element is directly stringified without calling upon the
    :meth:`_expression.ClauseElement.compile` method.
    It can render a limited set
    of non-standard SQL constructs to assist in basic stringification,
    however for more substantial custom or dialect-specific SQL constructs,
    it will be necessary to make use of
    :meth:`_expression.ClauseElement.compile`
    directly.

    .. seealso::

        :ref:`faq_sql_expression_string`

    """

    def _fallback_column_name(self, column):
        # placeholder used when a column has no name to render
        return "<name unknown>"

    @util.preload_module("sqlalchemy.engine.url")
    def visit_unsupported_compilation(self, element, err, **kw):
        """Attempt to render an element not supported by the generic
        string compiler by delegating to the compiler of the element's
        declared ``stringify_dialect``, falling back to the base
        behavior otherwise.

        """
        if element.stringify_dialect != "default":
            url = util.preloaded.engine_url
            dialect = url.URL.create(element.stringify_dialect).get_dialect()()

            compiler = dialect.statement_compiler(
                dialect, None, _supporting_against=self
            )
            # only delegate if the target dialect has a real compiler;
            # a StrSQLCompiler would just recurse into this path
            if not isinstance(compiler, StrSQLCompiler):
                return compiler.process(element, **kw)

        return super().visit_unsupported_compilation(element, err)

    def visit_getitem_binary(self, binary, operator, **kw):
        """Render an indexing operation as ``left[right]``."""
        return "%s[%s]" % (
            self.process(binary.left, **kw),
            self.process(binary.right, **kw),
        )

    def visit_json_getitem_op_binary(self, binary, operator, **kw):
        # JSON index access stringifies the same as generic indexing
        return self.visit_getitem_binary(binary, operator, **kw)

    def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
        # JSON path access stringifies the same as generic indexing
        return self.visit_getitem_binary(binary, operator, **kw)

    def visit_sequence(self, sequence, **kw):
        # sequences have no generic SQL form; emit a descriptive token
        return (
            f"<next sequence value: {self.preparer.format_sequence(sequence)}>"
        )

    def returning_clause(
        self,
        stmt: UpdateBase,
        returning_cols: Sequence[_ColumnsClauseElement],
        *,
        populate_result_map: bool,
        **kw: Any,
    ) -> str:
        """Render a generic RETURNING clause for stringification."""
        columns = [
            self._label_select_column(None, c, True, False, {})
            for c in base._select_iterables(returning_cols)
        ]
        return "RETURNING " + ", ".join(columns)

    def update_from_clause(
        self, update_stmt, from_table, extra_froms, from_hints, **kw
    ):
        """Render a generic UPDATE..FROM clause for stringification."""
        kw["asfrom"] = True
        return "FROM " + ", ".join(
            t._compiler_dispatch(self, fromhints=from_hints, **kw)
            for t in extra_froms
        )

    def delete_extra_from_clause(
        self, delete_stmt, from_table, extra_froms, from_hints, **kw
    ):
        """Render extra DELETE FROM elements as a comma-separated list."""
        kw["asfrom"] = True
        return ", " + ", ".join(
            t._compiler_dispatch(self, fromhints=from_hints, **kw)
            for t in extra_froms
        )

    def visit_empty_set_expr(self, element_types, **kw):
        # a SELECT that returns no rows, standing in for an empty IN list
        return "SELECT 1 WHERE 1!=1"

    def get_from_hint_text(self, table, text):
        return "[%s]" % text

    def visit_regexp_match_op_binary(self, binary, operator, **kw):
        # no generic SQL regexp operator exists; emit a placeholder
        return self._generate_generic_binary(binary, " <regexp> ", **kw)

    def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
        return self._generate_generic_binary(binary, " <not regexp> ", **kw)

    def visit_regexp_replace_op_binary(self, binary, operator, **kw):
        return "<regexp replace>(%s, %s)" % (
            binary.left._compiler_dispatch(self, **kw),
            binary.right._compiler_dispatch(self, **kw),
        )

    def visit_try_cast(self, cast, **kwargs):
        """Render TRY_CAST(expr AS type) for stringification."""
        return "TRY_CAST(%s AS %s)" % (
            cast.clause._compiler_dispatch(self, **kwargs),
            cast.typeclause._compiler_dispatch(self, **kwargs),
        )
6708
6709
class DDLCompiler(Compiled):
    """Renders DDL (data definition language) strings such as
    CREATE TABLE, DROP INDEX, ALTER TABLE ADD CONSTRAINT."""

    is_ddl = True

    if TYPE_CHECKING:

        def __init__(
            self,
            dialect: Dialect,
            statement: ExecutableDDLElement,
            schema_translate_map: Optional[SchemaTranslateMapType] = ...,
            render_schema_translate: bool = ...,
            compile_kwargs: Mapping[str, Any] = ...,
        ): ...

    @util.ro_memoized_property
    def sql_compiler(self) -> SQLCompiler:
        """A SQL expression compiler used for embedded expressions such
        as server defaults, CHECK constraints and computed columns."""
        return self.dialect.statement_compiler(
            self.dialect, None, schema_translate_map=self.schema_translate_map
        )

    @util.memoized_property
    def type_compiler(self):
        return self.dialect.type_compiler_instance

    def construct_params(
        self,
        params: Optional[_CoreSingleExecuteParams] = None,
        extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
        escape_names: bool = True,
    ) -> Optional[_MutableCoreSingleExecuteParams]:
        # DDL statements carry no bound parameters
        return None

    def visit_ddl(self, ddl, **kwargs):
        """Render a literal DDL string, substituting table/schema names
        via %-interpolation of the DDL context dict."""
        # table events can substitute table and schema name
        context = ddl.context
        if isinstance(ddl.target, schema.Table):
            context = context.copy()

            preparer = self.preparer
            path = preparer.format_table_seq(ddl.target)
            if len(path) == 1:
                table, sch = path[0], ""
            else:
                table, sch = path[-1], path[0]

            context.setdefault("table", table)
            context.setdefault("schema", sch)
            context.setdefault("fullname", preparer.format_table(ddl.target))

        return self.sql_compiler.post_process_text(ddl.statement % context)

    def visit_create_schema(self, create, **kw):
        """Render CREATE SCHEMA, honoring IF NOT EXISTS."""
        text = "CREATE SCHEMA "
        if create.if_not_exists:
            text += "IF NOT EXISTS "
        return text + self.preparer.format_schema(create.element)

    def visit_drop_schema(self, drop, **kw):
        """Render DROP SCHEMA, honoring IF EXISTS and CASCADE."""
        text = "DROP SCHEMA "
        if drop.if_exists:
            text += "IF EXISTS "
        text += self.preparer.format_schema(drop.element)
        if drop.cascade:
            text += " CASCADE"
        return text

    def visit_create_table(self, create, **kw):
        """Render a full CREATE TABLE statement including columns,
        constraints, table prefixes and dialect suffixes."""
        table = create.element
        preparer = self.preparer

        text = "\nCREATE "
        if table._prefixes:
            text += " ".join(table._prefixes) + " "

        text += "TABLE "
        if create.if_not_exists:
            text += "IF NOT EXISTS "

        text += preparer.format_table(table) + " "

        create_table_suffix = self.create_table_suffix(table)
        if create_table_suffix:
            text += create_table_suffix + " "

        text += "("

        separator = "\n"

        # if only one primary key, specify it along with the column
        first_pk = False
        for create_column in create.columns:
            column = create_column.element
            try:
                processed = self.process(
                    create_column, first_pk=column.primary_key and not first_pk
                )
                # a None result means the column is skipped entirely
                # (e.g. system columns)
                if processed is not None:
                    text += separator
                    separator = ", \n"
                    text += "\t" + processed
                if column.primary_key:
                    first_pk = True
            except exc.CompileError as ce:
                # re-raise with table/column context for diagnostics
                raise exc.CompileError(
                    "(in table '%s', column '%s'): %s"
                    % (table.description, column.name, ce.args[0])
                ) from ce

        const = self.create_table_constraints(
            table,
            _include_foreign_key_constraints=create.include_foreign_key_constraints, # noqa
        )
        if const:
            text += separator + "\t" + const

        text += "\n)%s\n\n" % self.post_create_table(table)
        return text

    def visit_create_column(self, create, first_pk=False, **kw):
        """Render a single column specification inside CREATE TABLE;
        returns None for system columns, which are not rendered."""
        column = create.element

        if column.system:
            return None

        text = self.get_column_specification(column, first_pk=first_pk)
        const = " ".join(
            self.process(constraint) for constraint in column.constraints
        )
        if const:
            text += " " + const

        return text

    def create_table_constraints(
        self, table, _include_foreign_key_constraints=None, **kw
    ):
        """Render the table-level constraints for CREATE TABLE as a
        comma/newline separated string."""
        # On some DB order is significant: visit PK first, then the
        # other constraints (engine.ReflectionTest.testbasic failed on FB2)
        constraints = []
        if table.primary_key:
            constraints.append(table.primary_key)

        all_fkcs = table.foreign_key_constraints
        if _include_foreign_key_constraints is not None:
            omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints)
        else:
            omit_fkcs = set()

        constraints.extend(
            [
                c
                for c in table._sorted_constraints
                if c is not table.primary_key and c not in omit_fkcs
            ]
        )

        # skip constraints the dialect says not to create inline, and
        # use_alter constraints when ALTER is supported (rendered later)
        return ", \n\t".join(
            p
            for p in (
                self.process(constraint)
                for constraint in constraints
                if (constraint._should_create_for_compiler(self))
                and (
                    not self.dialect.supports_alter
                    or not getattr(constraint, "use_alter", False)
                )
            )
            if p is not None
        )

    def visit_drop_table(self, drop, **kw):
        """Render DROP TABLE, honoring IF EXISTS."""
        text = "\nDROP TABLE "
        if drop.if_exists:
            text += "IF EXISTS "
        return text + self.preparer.format_table(drop.element)

    def visit_drop_view(self, drop, **kw):
        return "\nDROP VIEW " + self.preparer.format_table(drop.element)

    def _verify_index_table(self, index: Index) -> None:
        # an index must be bound to a table before DDL can be emitted
        if index.table is None:
            raise exc.CompileError(
                "Index '%s' is not associated with any table." % index.name
            )

    def visit_create_index(
        self, create, include_schema=False, include_table_schema=True, **kw
    ):
        """Render CREATE [UNIQUE] INDEX with its expression list.

        :param include_schema: qualify the index name with its schema
        :param include_table_schema: qualify the table name with its schema

        """
        index = create.element
        self._verify_index_table(index)
        preparer = self.preparer
        text = "CREATE "
        if index.unique:
            text += "UNIQUE "
        if index.name is None:
            raise exc.CompileError(
                "CREATE INDEX requires that the index have a name"
            )

        text += "INDEX "
        if create.if_not_exists:
            text += "IF NOT EXISTS "

        # index expressions are rendered with literal binds since DDL
        # cannot carry bound parameters
        text += "%s ON %s (%s)" % (
            self._prepared_index_name(index, include_schema=include_schema),
            preparer.format_table(
                index.table, use_schema=include_table_schema
            ),
            ", ".join(
                self.sql_compiler.process(
                    expr, include_table=False, literal_binds=True
                )
                for expr in index.expressions
            ),
        )
        return text

    def visit_drop_index(self, drop, **kw):
        """Render DROP INDEX, honoring IF EXISTS."""
        index = drop.element

        if index.name is None:
            raise exc.CompileError(
                "DROP INDEX requires that the index have a name"
            )
        text = "\nDROP INDEX "
        if drop.if_exists:
            text += "IF EXISTS "

        return text + self._prepared_index_name(index, include_schema=True)

    def _prepared_index_name(
        self, index: Index, include_schema: bool = False
    ) -> str:
        """Return the quoted index name, optionally schema-qualified."""
        if index.table is not None:
            effective_schema = self.preparer.schema_for_object(index.table)
        else:
            effective_schema = None
        if include_schema and effective_schema:
            schema_name = self.preparer.quote_schema(effective_schema)
        else:
            schema_name = None

        index_name: str = self.preparer.format_index(index)

        if schema_name:
            index_name = schema_name + "." + index_name
        return index_name

    def visit_add_constraint(self, create, **kw):
        return "ALTER TABLE %s ADD %s" % (
            self.preparer.format_table(create.element.table),
            self.process(create.element),
        )

    def visit_set_table_comment(self, create, **kw):
        return "COMMENT ON TABLE %s IS %s" % (
            self.preparer.format_table(create.element),
            self.sql_compiler.render_literal_value(
                create.element.comment, sqltypes.String()
            ),
        )

    def visit_drop_table_comment(self, drop, **kw):
        return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
            drop.element
        )

    def visit_set_column_comment(self, create, **kw):
        return "COMMENT ON COLUMN %s IS %s" % (
            self.preparer.format_column(
                create.element, use_table=True, use_schema=True
            ),
            self.sql_compiler.render_literal_value(
                create.element.comment, sqltypes.String()
            ),
        )

    def visit_drop_column_comment(self, drop, **kw):
        return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column(
            drop.element, use_table=True
        )

    def visit_set_constraint_comment(self, create, **kw):
        # no generic syntax; dialects that support it override
        raise exc.UnsupportedCompilationError(self, type(create))

    def visit_drop_constraint_comment(self, drop, **kw):
        # no generic syntax; dialects that support it override
        raise exc.UnsupportedCompilationError(self, type(drop))

    def get_identity_options(self, identity_options):
        """Render the option keywords shared by sequences and identity
        columns (INCREMENT BY, START WITH, MINVALUE, CYCLE, etc.)."""
        text = []
        if identity_options.increment is not None:
            text.append("INCREMENT BY %d" % identity_options.increment)
        if identity_options.start is not None:
            text.append("START WITH %d" % identity_options.start)
        if identity_options.minvalue is not None:
            text.append("MINVALUE %d" % identity_options.minvalue)
        if identity_options.maxvalue is not None:
            text.append("MAXVALUE %d" % identity_options.maxvalue)
        # NOTE(review): NO MINVALUE / NO MAXVALUE render whenever the
        # attribute is not None, regardless of its truth value — confirm
        # callers only ever set these to True or None
        if identity_options.nominvalue is not None:
            text.append("NO MINVALUE")
        if identity_options.nomaxvalue is not None:
            text.append("NO MAXVALUE")
        if identity_options.cache is not None:
            text.append("CACHE %d" % identity_options.cache)
        if identity_options.cycle is not None:
            text.append("CYCLE" if identity_options.cycle else "NO CYCLE")
        return " ".join(text)

    def visit_create_sequence(self, create, prefix=None, **kw):
        """Render CREATE SEQUENCE with its identity options."""
        text = "CREATE SEQUENCE "
        if create.if_not_exists:
            text += "IF NOT EXISTS "
        text += self.preparer.format_sequence(create.element)

        if prefix:
            text += prefix
        options = self.get_identity_options(create.element)
        if options:
            text += " " + options
        return text

    def visit_drop_sequence(self, drop, **kw):
        text = "DROP SEQUENCE "
        if drop.if_exists:
            text += "IF EXISTS "
        return text + self.preparer.format_sequence(drop.element)

    def visit_drop_constraint(self, drop, **kw):
        """Render ALTER TABLE .. DROP CONSTRAINT; the constraint must
        have a name."""
        constraint = drop.element
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
        else:
            formatted_name = None

        if formatted_name is None:
            raise exc.CompileError(
                "Can't emit DROP CONSTRAINT for constraint %r; "
                "it has no name" % drop.element
            )
        return "ALTER TABLE %s DROP CONSTRAINT %s%s%s" % (
            self.preparer.format_table(drop.element.table),
            "IF EXISTS " if drop.if_exists else "",
            formatted_name,
            " CASCADE" if drop.cascade else "",
        )

    def get_column_specification(self, column, **kwargs):
        """Render a column's name, type, DEFAULT, computed/identity
        clauses and NOT NULL for CREATE TABLE."""
        colspec = (
            self.preparer.format_column(column)
            + " "
            + self.dialect.type_compiler_instance.process(
                column.type, type_expression=column
            )
        )
        default = self.get_column_default_string(column)
        if default is not None:
            colspec += " DEFAULT " + default

        if column.computed is not None:
            colspec += " " + self.process(column.computed)

        if (
            column.identity is not None
            and self.dialect.supports_identity_columns
        ):
            colspec += " " + self.process(column.identity)

        # identity columns imply NOT NULL on dialects that support them
        if not column.nullable and (
            not column.identity or not self.dialect.supports_identity_columns
        ):
            colspec += " NOT NULL"
        return colspec

    def create_table_suffix(self, table):
        # hook for dialects to add text between the table name and "("
        return ""

    def post_create_table(self, table):
        # hook for dialects to add text after the closing ")"
        return ""

    def get_column_default_string(self, column: Column[Any]) -> Optional[str]:
        """Return the rendered server default for a column, or None."""
        if isinstance(column.server_default, schema.DefaultClause):
            return self.render_default_string(column.server_default.arg)
        else:
            return None

    def render_default_string(self, default: Union[Visitable, str]) -> str:
        """Render a server default: quote plain strings as literals,
        otherwise compile the SQL expression with literal binds."""
        if isinstance(default, str):
            return self.sql_compiler.render_literal_value(
                default, sqltypes.STRINGTYPE
            )
        else:
            return self.sql_compiler.process(default, literal_binds=True)

    def visit_table_or_column_check_constraint(self, constraint, **kw):
        # dispatch based on whether the CHECK was declared on a column
        if constraint.is_column_level:
            return self.visit_column_check_constraint(constraint)
        else:
            return self.visit_check_constraint(constraint)

    def visit_check_constraint(self, constraint, **kw):
        """Render a table-level CHECK constraint."""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "CHECK (%s)" % self.sql_compiler.process(
            constraint.sqltext, include_table=False, literal_binds=True
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_column_check_constraint(self, constraint, **kw):
        """Render a column-level CHECK constraint."""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "CHECK (%s)" % self.sql_compiler.process(
            constraint.sqltext, include_table=False, literal_binds=True
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_primary_key_constraint(
        self, constraint: PrimaryKeyConstraint, **kw: Any
    ) -> str:
        """Render a PRIMARY KEY constraint; empty constraints render as
        an empty string."""
        if len(constraint) == 0:
            return ""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "PRIMARY KEY "
        # implicitly generated PKs list the autoincrement column first
        text += "(%s)" % ", ".join(
            self.preparer.quote(c.name)
            for c in (
                constraint.columns_autoinc_first
                if constraint._implicit_generated
                else constraint.columns
            )
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_foreign_key_constraint(self, constraint, **kw):
        """Render a FOREIGN KEY constraint with its MATCH, cascade and
        deferrability options."""
        preparer = self.preparer
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        remote_table = list(constraint.elements)[0].column.table
        text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
            ", ".join(
                preparer.quote(f.parent.name) for f in constraint.elements
            ),
            self.define_constraint_remote_table(
                constraint, remote_table, preparer
            ),
            ", ".join(
                preparer.quote(f.column.name) for f in constraint.elements
            ),
        )
        text += self.define_constraint_match(constraint)
        text += self.define_constraint_cascades(constraint)
        text += self.define_constraint_deferrability(constraint)
        return text

    def define_constraint_remote_table(self, constraint, table, preparer):
        """Format the remote table clause of a CREATE CONSTRAINT clause."""

        return preparer.format_table(table)

    def visit_unique_constraint(
        self, constraint: UniqueConstraint, **kw: Any
    ) -> str:
        """Render a UNIQUE constraint; empty constraints render as an
        empty string."""
        if len(constraint) == 0:
            return ""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "UNIQUE %s(%s)" % (
            self.define_unique_constraint_distinct(constraint, **kw),
            ", ".join(self.preparer.quote(c.name) for c in constraint),
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def define_unique_constraint_distinct(
        self, constraint: UniqueConstraint, **kw: Any
    ) -> str:
        # hook for dialects that support e.g. UNIQUE NULLS NOT DISTINCT
        return ""

    def define_constraint_cascades(
        self, constraint: ForeignKeyConstraint
    ) -> str:
        """Render ON DELETE / ON UPDATE actions for a FK constraint."""
        text = ""
        if constraint.ondelete is not None:
            text += self.define_constraint_ondelete_cascade(constraint)

        if constraint.onupdate is not None:
            text += self.define_constraint_onupdate_cascade(constraint)
        return text

    def define_constraint_ondelete_cascade(
        self, constraint: ForeignKeyConstraint
    ) -> str:
        return " ON DELETE %s" % self.preparer.validate_sql_phrase(
            constraint.ondelete, FK_ON_DELETE
        )

    def define_constraint_onupdate_cascade(
        self, constraint: ForeignKeyConstraint
    ) -> str:
        return " ON UPDATE %s" % self.preparer.validate_sql_phrase(
            constraint.onupdate, FK_ON_UPDATE
        )

    def define_constraint_deferrability(self, constraint: Constraint) -> str:
        """Render DEFERRABLE / NOT DEFERRABLE and INITIALLY clauses."""
        text = ""
        if constraint.deferrable is not None:
            if constraint.deferrable:
                text += " DEFERRABLE"
            else:
                text += " NOT DEFERRABLE"
        if constraint.initially is not None:
            text += " INITIALLY %s" % self.preparer.validate_sql_phrase(
                constraint.initially, FK_INITIALLY
            )
        return text

    def define_constraint_match(self, constraint):
        """Render the MATCH clause of a FK constraint."""
        text = ""
        if constraint.match is not None:
            text += " MATCH %s" % constraint.match
        return text

    def visit_computed_column(self, generated, **kw):
        """Render a GENERATED ALWAYS AS (expr) [STORED|VIRTUAL] clause."""
        text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
            generated.sqltext, include_table=False, literal_binds=True
        )
        # persisted=None omits the storage keyword entirely
        if generated.persisted is True:
            text += " STORED"
        elif generated.persisted is False:
            text += " VIRTUAL"
        return text

    def visit_identity_column(self, identity, **kw):
        """Render a GENERATED { ALWAYS | BY DEFAULT } AS IDENTITY clause
        with its sequence options."""
        text = "GENERATED %s AS IDENTITY" % (
            "ALWAYS" if identity.always else "BY DEFAULT",
        )
        options = self.get_identity_options(identity)
        if options:
            text += " (%s)" % options
        return text
7268
7269
7270class GenericTypeCompiler(TypeCompiler):
    def visit_FLOAT(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
        """Render the generic FLOAT type."""
        return "FLOAT"
7273
    def visit_DOUBLE(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
        """Render the generic DOUBLE type."""
        return "DOUBLE"
7276
    def visit_DOUBLE_PRECISION(
        self, type_: sqltypes.DOUBLE_PRECISION[Any], **kw: Any
    ) -> str:
        """Render the generic DOUBLE PRECISION type."""
        return "DOUBLE PRECISION"
7281
    def visit_REAL(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
        """Render the generic REAL type."""
        return "REAL"
7284
7285 def visit_NUMERIC(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
7286 if type_.precision is None:
7287 return "NUMERIC"
7288 elif type_.scale is None:
7289 return "NUMERIC(%(precision)s)" % {"precision": type_.precision}
7290 else:
7291 return "NUMERIC(%(precision)s, %(scale)s)" % {
7292 "precision": type_.precision,
7293 "scale": type_.scale,
7294 }
7295
7296 def visit_DECIMAL(self, type_: sqltypes.DECIMAL[Any], **kw: Any) -> str:
7297 if type_.precision is None:
7298 return "DECIMAL"
7299 elif type_.scale is None:
7300 return "DECIMAL(%(precision)s)" % {"precision": type_.precision}
7301 else:
7302 return "DECIMAL(%(precision)s, %(scale)s)" % {
7303 "precision": type_.precision,
7304 "scale": type_.scale,
7305 }
7306
7307 def visit_INTEGER(self, type_: sqltypes.Integer, **kw: Any) -> str:
7308 return "INTEGER"
7309
7310 def visit_SMALLINT(self, type_: sqltypes.SmallInteger, **kw: Any) -> str:
7311 return "SMALLINT"
7312
7313 def visit_BIGINT(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
7314 return "BIGINT"
7315
7316 def visit_TIMESTAMP(self, type_: sqltypes.TIMESTAMP, **kw: Any) -> str:
7317 return "TIMESTAMP"
7318
7319 def visit_DATETIME(self, type_: sqltypes.DateTime, **kw: Any) -> str:
7320 return "DATETIME"
7321
7322 def visit_DATE(self, type_: sqltypes.Date, **kw: Any) -> str:
7323 return "DATE"
7324
7325 def visit_TIME(self, type_: sqltypes.Time, **kw: Any) -> str:
7326 return "TIME"
7327
7328 def visit_CLOB(self, type_: sqltypes.CLOB, **kw: Any) -> str:
7329 return "CLOB"
7330
7331 def visit_NCLOB(self, type_: sqltypes.Text, **kw: Any) -> str:
7332 return "NCLOB"
7333
7334 def _render_string_type(
7335 self, name: str, length: Optional[int], collation: Optional[str]
7336 ) -> str:
7337 text = name
7338 if length:
7339 text += f"({length})"
7340 if collation:
7341 text += f' COLLATE "{collation}"'
7342 return text
7343
7344 def visit_CHAR(self, type_: sqltypes.CHAR, **kw: Any) -> str:
7345 return self._render_string_type("CHAR", type_.length, type_.collation)
7346
7347 def visit_NCHAR(self, type_: sqltypes.NCHAR, **kw: Any) -> str:
7348 return self._render_string_type("NCHAR", type_.length, type_.collation)
7349
7350 def visit_VARCHAR(self, type_: sqltypes.String, **kw: Any) -> str:
7351 return self._render_string_type(
7352 "VARCHAR", type_.length, type_.collation
7353 )
7354
7355 def visit_NVARCHAR(self, type_: sqltypes.NVARCHAR, **kw: Any) -> str:
7356 return self._render_string_type(
7357 "NVARCHAR", type_.length, type_.collation
7358 )
7359
7360 def visit_TEXT(self, type_: sqltypes.Text, **kw: Any) -> str:
7361 return self._render_string_type("TEXT", type_.length, type_.collation)
7362
7363 def visit_UUID(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
7364 return "UUID"
7365
7366 def visit_BLOB(self, type_: sqltypes.LargeBinary, **kw: Any) -> str:
7367 return "BLOB"
7368
7369 def visit_BINARY(self, type_: sqltypes.BINARY, **kw: Any) -> str:
7370 return "BINARY" + (type_.length and "(%d)" % type_.length or "")
7371
7372 def visit_VARBINARY(self, type_: sqltypes.VARBINARY, **kw: Any) -> str:
7373 return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
7374
7375 def visit_BOOLEAN(self, type_: sqltypes.Boolean, **kw: Any) -> str:
7376 return "BOOLEAN"
7377
7378 def visit_uuid(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
7379 if not type_.native_uuid or not self.dialect.supports_native_uuid:
7380 return self._render_string_type("CHAR", length=32, collation=None)
7381 else:
7382 return self.visit_UUID(type_, **kw)
7383
7384 def visit_large_binary(
7385 self, type_: sqltypes.LargeBinary, **kw: Any
7386 ) -> str:
7387 return self.visit_BLOB(type_, **kw)
7388
7389 def visit_boolean(self, type_: sqltypes.Boolean, **kw: Any) -> str:
7390 return self.visit_BOOLEAN(type_, **kw)
7391
7392 def visit_time(self, type_: sqltypes.Time, **kw: Any) -> str:
7393 return self.visit_TIME(type_, **kw)
7394
7395 def visit_datetime(self, type_: sqltypes.DateTime, **kw: Any) -> str:
7396 return self.visit_DATETIME(type_, **kw)
7397
7398 def visit_date(self, type_: sqltypes.Date, **kw: Any) -> str:
7399 return self.visit_DATE(type_, **kw)
7400
7401 def visit_big_integer(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
7402 return self.visit_BIGINT(type_, **kw)
7403
7404 def visit_small_integer(
7405 self, type_: sqltypes.SmallInteger, **kw: Any
7406 ) -> str:
7407 return self.visit_SMALLINT(type_, **kw)
7408
7409 def visit_integer(self, type_: sqltypes.Integer, **kw: Any) -> str:
7410 return self.visit_INTEGER(type_, **kw)
7411
7412 def visit_real(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
7413 return self.visit_REAL(type_, **kw)
7414
7415 def visit_float(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
7416 return self.visit_FLOAT(type_, **kw)
7417
7418 def visit_double(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
7419 return self.visit_DOUBLE(type_, **kw)
7420
7421 def visit_numeric(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
7422 return self.visit_NUMERIC(type_, **kw)
7423
7424 def visit_string(self, type_: sqltypes.String, **kw: Any) -> str:
7425 return self.visit_VARCHAR(type_, **kw)
7426
7427 def visit_unicode(self, type_: sqltypes.Unicode, **kw: Any) -> str:
7428 return self.visit_VARCHAR(type_, **kw)
7429
7430 def visit_text(self, type_: sqltypes.Text, **kw: Any) -> str:
7431 return self.visit_TEXT(type_, **kw)
7432
7433 def visit_unicode_text(
7434 self, type_: sqltypes.UnicodeText, **kw: Any
7435 ) -> str:
7436 return self.visit_TEXT(type_, **kw)
7437
7438 def visit_enum(self, type_: sqltypes.Enum, **kw: Any) -> str:
7439 return self.visit_VARCHAR(type_, **kw)
7440
7441 def visit_null(self, type_, **kw):
7442 raise exc.CompileError(
7443 "Can't generate DDL for %r; "
7444 "did you forget to specify a "
7445 "type on this Column?" % type_
7446 )
7447
7448 def visit_type_decorator(
7449 self, type_: TypeDecorator[Any], **kw: Any
7450 ) -> str:
7451 return self.process(type_.type_engine(self.dialect), **kw)
7452
7453 def visit_user_defined(
7454 self, type_: UserDefinedType[Any], **kw: Any
7455 ) -> str:
7456 return type_.get_col_spec(**kw)
7457
7458
class StrSQLTypeCompiler(GenericTypeCompiler):
    """Type compiler for dialect-less "string SQL" output; renders a
    best-effort name for any type rather than raising."""

    def process(self, type_, **kw):
        try:
            dispatch = type_._compiler_dispatch
        except AttributeError:
            # not a proper visitable type; render something readable
            return self._visit_unknown(type_, **kw)
        return dispatch(self, **kw)

    def __getattr__(self, key):
        # any visit_* method not defined here or on the base resolves
        # to the generic fallback, so unknown types never error
        if not key.startswith("visit_"):
            raise AttributeError(key)
        return self._visit_unknown

    def _visit_unknown(self, type_, **kw):
        cls_name = type_.__class__.__name__
        # all-uppercase class names read as SQL keywords; render those
        # bare, anything else via repr()
        if cls_name == cls_name.upper():
            return cls_name
        return repr(type_)

    def visit_null(self, type_, **kw):
        return "NULL"

    def visit_user_defined(self, type_, **kw):
        try:
            get_col_spec = type_.get_col_spec
        except AttributeError:
            return repr(type_)
        return get_col_spec(**kw)
7490
7491
class _SchemaForObjectCallable(Protocol):
    # callable protocol: given a schema object, return its effective
    # schema string (see IdentifierPreparer.schema_for_object)
    def __call__(self, obj: Any, /) -> str: ...
7494
7495
class _BindNameForColProtocol(Protocol):
    # callable protocol: given a column, return the bind parameter name
    # to use for it
    def __call__(self, col: ColumnClause[Any]) -> str: ...
7498
7499
class IdentifierPreparer:
    """Handle quoting and case-folding of identifiers based on options."""

    # words that force quoting when used as identifiers
    # (tested via membership in _requires_quotes)
    reserved_words = RESERVED_WORDS

    # regex-like object whose .match() accepts identifiers that are
    # legal without quoting
    legal_characters = LEGAL_CHARACTERS

    # characters that may not begin an unquoted identifier
    illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS

    # character that opens a delimited identifier
    initial_quote: str

    # character that closes a delimited identifier
    final_quote: str

    # per-preparer cache of identifier -> rendered (possibly quoted) form
    _strings: MutableMapping[str, str]

    schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema")
    """Return the .schema attribute for an object.

    For the default IdentifierPreparer, the schema for an object is always
    the value of the ".schema" attribute. if the preparer is replaced
    with one that has a non-empty schema_translate_map, the value of the
    ".schema" attribute is rendered a symbol that will be converted to a
    real schema name from the mapping post-compile.

    """

    # True when a schema_translate_map containing a ``None`` key was
    # applied via _with_schema_translate
    _includes_none_schema_translate: bool = False
7527
7528 def __init__(
7529 self,
7530 dialect: Dialect,
7531 initial_quote: str = '"',
7532 final_quote: Optional[str] = None,
7533 escape_quote: str = '"',
7534 quote_case_sensitive_collations: bool = True,
7535 omit_schema: bool = False,
7536 ):
7537 """Construct a new ``IdentifierPreparer`` object.
7538
7539 initial_quote
7540 Character that begins a delimited identifier.
7541
7542 final_quote
7543 Character that ends a delimited identifier. Defaults to
7544 `initial_quote`.
7545
7546 omit_schema
7547 Prevent prepending schema name. Useful for databases that do
7548 not support schemae.
7549 """
7550
7551 self.dialect = dialect
7552 self.initial_quote = initial_quote
7553 self.final_quote = final_quote or self.initial_quote
7554 self.escape_quote = escape_quote
7555 self.escape_to_quote = self.escape_quote * 2
7556 self.omit_schema = omit_schema
7557 self.quote_case_sensitive_collations = quote_case_sensitive_collations
7558 self._strings = {}
7559 self._double_percents = self.dialect.paramstyle in (
7560 "format",
7561 "pyformat",
7562 )
7563
    def _with_schema_translate(self, schema_translate_map):
        """Return a copy of this preparer whose ``schema_for_object``
        renders symbolic placeholders instead of literal schema names.

        The ``__[SCHEMA_<name>]`` placeholders are later substituted with
        real schema names by :meth:`._render_schema_translates`.
        """
        # shallow-copy self without running __init__
        prep = self.__class__.__new__(self.__class__)
        prep.__dict__.update(self.__dict__)

        includes_none = None in schema_translate_map

        def symbol_getter(obj):
            name = obj.schema
            if obj._use_schema_map and (name is not None or includes_none):
                # square brackets would corrupt the post-compile regex
                # substitution, so reject them up front
                if name is not None and ("[" in name or "]" in name):
                    raise exc.CompileError(
                        "Square bracket characters ([]) not supported "
                        "in schema translate name '%s'" % name
                    )
                # "_none" stands in for a None schema in the placeholder
                return quoted_name(
                    "__[SCHEMA_%s]" % (name or "_none"), quote=False
                )
            else:
                return obj.schema

        prep.schema_for_object = symbol_getter
        prep._includes_none_schema_translate = includes_none
        return prep
7587
    def _render_schema_translates(
        self, statement: str, schema_translate_map: SchemaTranslateMapType
    ) -> str:
        """Replace ``__[SCHEMA_<name>]`` placeholders in a compiled
        statement with real schema names from *schema_translate_map*.

        Raises ``InvalidRequestError`` when the map's keys are not
        consistent with the map that was in effect at compile time.
        """
        d = schema_translate_map
        if None in d:
            if not self._includes_none_schema_translate:
                raise exc.InvalidRequestError(
                    "schema translate map which previously did not have "
                    "`None` present as a key now has `None` present; compiled "
                    "statement may lack adequate placeholders. Please use "
                    "consistent keys in successive "
                    "schema_translate_map dictionaries."
                )

            # the placeholder rendered for a None schema uses the key
            # "_none"; mirror the None entry under that key
            d["_none"] = d[None]  # type: ignore[index]

        def replace(m):
            # group(2) is the schema name captured from the placeholder
            name = m.group(2)
            if name in d:
                effective_schema = d[name]
            else:
                if name in (None, "_none"):
                    raise exc.InvalidRequestError(
                        "schema translate map which previously had `None` "
                        "present as a key now no longer has it present; don't "
                        "know how to apply schema for compiled statement. "
                        "Please use consistent keys in successive "
                        "schema_translate_map dictionaries."
                    )
                # name not in the map: pass it through unchanged
                effective_schema = name

            if not effective_schema:
                # a falsy target schema means "use the dialect default"
                effective_schema = self.dialect.default_schema_name
                if not effective_schema:
                    # TODO: no coverage here
                    raise exc.CompileError(
                        "Dialect has no default schema name; can't "
                        "use None as dynamic schema target."
                    )
            return self.quote_schema(effective_schema)

        return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement)
7630
7631 def _escape_identifier(self, value: str) -> str:
7632 """Escape an identifier.
7633
7634 Subclasses should override this to provide database-dependent
7635 escaping behavior.
7636 """
7637
7638 value = value.replace(self.escape_quote, self.escape_to_quote)
7639 if self._double_percents:
7640 value = value.replace("%", "%%")
7641 return value
7642
    def _unescape_identifier(self, value: str) -> str:
        """Canonicalize an escaped identifier.

        Subclasses should override this to provide database-dependent
        unescaping behavior that reverses _escape_identifier.
        """

        # collapse doubled quote sequences back to a single quote char
        return value.replace(self.escape_to_quote, self.escape_quote)
7651
7652 def validate_sql_phrase(self, element, reg):
7653 """keyword sequence filter.
7654
7655 a filter for elements that are intended to represent keyword sequences,
7656 such as "INITIALLY", "INITIALLY DEFERRED", etc. no special characters
7657 should be present.
7658
7659 """
7660
7661 if element is not None and not reg.match(element):
7662 raise exc.CompileError(
7663 "Unexpected SQL phrase: %r (matching against %r)"
7664 % (element, reg.pattern)
7665 )
7666 return element
7667
7668 def quote_identifier(self, value: str) -> str:
7669 """Quote an identifier.
7670
7671 Subclasses should override this to provide database-dependent
7672 quoting behavior.
7673 """
7674
7675 return (
7676 self.initial_quote
7677 + self._escape_identifier(value)
7678 + self.final_quote
7679 )
7680
7681 def _requires_quotes(self, value: str) -> bool:
7682 """Return True if the given identifier requires quoting."""
7683 lc_value = value.lower()
7684 return (
7685 lc_value in self.reserved_words
7686 or value[0] in self.illegal_initial_characters
7687 or not self.legal_characters.match(str(value))
7688 or (lc_value != value)
7689 )
7690
    def _requires_quotes_illegal_chars(self, value):
        """Return True if the given identifier requires quoting, but
        not taking case convention into account."""
        # unlike _requires_quotes, reserved words and case sensitivity
        # are intentionally not considered here
        return not self.legal_characters.match(str(value))
7695
    def quote_schema(self, schema: str) -> str:
        """Conditionally quote a schema name.


        The name is quoted if it is a reserved word, contains quote-necessary
        characters, or is an instance of :class:`.quoted_name` which includes
        ``quote`` set to ``True``.

        Subclasses can override this to provide database-dependent
        quoting behavior for schema names.

        :param schema: string schema name
        """
        # default implementation defers to generic identifier quoting
        return self.quote(schema)
7710
7711 def quote(self, ident: str) -> str:
7712 """Conditionally quote an identifier.
7713
7714 The identifier is quoted if it is a reserved word, contains
7715 quote-necessary characters, or is an instance of
7716 :class:`.quoted_name` which includes ``quote`` set to ``True``.
7717
7718 Subclasses can override this to provide database-dependent
7719 quoting behavior for identifier names.
7720
7721 :param ident: string identifier
7722 """
7723 force = getattr(ident, "quote", None)
7724
7725 if force is None:
7726 if ident in self._strings:
7727 return self._strings[ident]
7728 else:
7729 if self._requires_quotes(ident):
7730 self._strings[ident] = self.quote_identifier(ident)
7731 else:
7732 self._strings[ident] = ident
7733 return self._strings[ident]
7734 elif force:
7735 return self.quote_identifier(ident)
7736 else:
7737 return ident
7738
7739 def format_collation(self, collation_name):
7740 if self.quote_case_sensitive_collations:
7741 return self.quote(collation_name)
7742 else:
7743 return collation_name
7744
7745 def format_sequence(
7746 self, sequence: schema.Sequence, use_schema: bool = True
7747 ) -> str:
7748 name = self.quote(sequence.name)
7749
7750 effective_schema = self.schema_for_object(sequence)
7751
7752 if (
7753 not self.omit_schema
7754 and use_schema
7755 and effective_schema is not None
7756 ):
7757 name = self.quote_schema(effective_schema) + "." + name
7758 return name
7759
7760 def format_label(
7761 self, label: Label[Any], name: Optional[str] = None
7762 ) -> str:
7763 return self.quote(name or label.name)
7764
7765 def format_alias(
7766 self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None
7767 ) -> str:
7768 if name is None:
7769 assert alias is not None
7770 return self.quote(alias.name)
7771 else:
7772 return self.quote(name)
7773
7774 def format_savepoint(self, savepoint, name=None):
7775 # Running the savepoint name through quoting is unnecessary
7776 # for all known dialects. This is here to support potential
7777 # third party use cases
7778 ident = name or savepoint.ident
7779 if self._requires_quotes(ident):
7780 ident = self.quote_identifier(ident)
7781 return ident
7782
    @util.preload_module("sqlalchemy.sql.naming")
    def format_constraint(
        self, constraint: Union[Constraint, Index], _alembic_quote: bool = True
    ) -> Optional[str]:
        """Prepare a constraint or index name for rendering.

        Returns ``None`` when the constraint has no explicit name and the
        naming convention system produces none either.
        """
        naming = util.preloaded.sql_naming

        # _NONE_NAME is the sentinel for "no explicit name given";
        # consult the naming convention system for a generated name
        if constraint.name is _NONE_NAME:
            name = naming._constraint_name_for_table(
                constraint, constraint.table
            )

            if name is None:
                return None
        else:
            name = constraint.name

        assert name is not None
        # indexes and constraints have separate max-length limits
        if constraint.__visit_name__ == "index":
            return self.truncate_and_render_index_name(
                name, _alembic_quote=_alembic_quote
            )
        else:
            return self.truncate_and_render_constraint_name(
                name, _alembic_quote=_alembic_quote
            )
7808
7809 def truncate_and_render_index_name(
7810 self, name: str, _alembic_quote: bool = True
7811 ) -> str:
7812 # calculate these at format time so that ad-hoc changes
7813 # to dialect.max_identifier_length etc. can be reflected
7814 # as IdentifierPreparer is long lived
7815 max_ = (
7816 self.dialect.max_index_name_length
7817 or self.dialect.max_identifier_length
7818 )
7819 return self._truncate_and_render_maxlen_name(
7820 name, max_, _alembic_quote
7821 )
7822
7823 def truncate_and_render_constraint_name(
7824 self, name: str, _alembic_quote: bool = True
7825 ) -> str:
7826 # calculate these at format time so that ad-hoc changes
7827 # to dialect.max_identifier_length etc. can be reflected
7828 # as IdentifierPreparer is long lived
7829 max_ = (
7830 self.dialect.max_constraint_name_length
7831 or self.dialect.max_identifier_length
7832 )
7833 return self._truncate_and_render_maxlen_name(
7834 name, max_, _alembic_quote
7835 )
7836
7837 def _truncate_and_render_maxlen_name(
7838 self, name: str, max_: int, _alembic_quote: bool
7839 ) -> str:
7840 if isinstance(name, elements._truncated_label):
7841 if len(name) > max_:
7842 name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
7843 else:
7844 self.dialect.validate_identifier(name)
7845
7846 if not _alembic_quote:
7847 return name
7848 else:
7849 return self.quote(name)
7850
7851 def format_index(self, index: Index) -> str:
7852 name = self.format_constraint(index)
7853 assert name is not None
7854 return name
7855
7856 def format_table(
7857 self,
7858 table: FromClause,
7859 use_schema: bool = True,
7860 name: Optional[str] = None,
7861 ) -> str:
7862 """Prepare a quoted table and schema name."""
7863 if name is None:
7864 if TYPE_CHECKING:
7865 assert isinstance(table, NamedFromClause)
7866 name = table.name
7867
7868 result = self.quote(name)
7869
7870 effective_schema = self.schema_for_object(table)
7871
7872 if not self.omit_schema and use_schema and effective_schema:
7873 result = self.quote_schema(effective_schema) + "." + result
7874 return result
7875
    def format_schema(self, name):
        """Prepare a quoted schema name."""

        # schema names use the generic identifier quoting rules
        return self.quote(name)
7880
7881 def format_label_name(
7882 self,
7883 name,
7884 anon_map=None,
7885 ):
7886 """Prepare a quoted column name."""
7887
7888 if anon_map is not None and isinstance(
7889 name, elements._truncated_label
7890 ):
7891 name = name.apply_map(anon_map)
7892
7893 return self.quote(name)
7894
7895 def format_column(
7896 self,
7897 column: ColumnElement[Any],
7898 use_table: bool = False,
7899 name: Optional[str] = None,
7900 table_name: Optional[str] = None,
7901 use_schema: bool = False,
7902 anon_map: Optional[Mapping[str, Any]] = None,
7903 ) -> str:
7904 """Prepare a quoted column name."""
7905
7906 if name is None:
7907 name = column.name
7908 assert name is not None
7909
7910 if anon_map is not None and isinstance(
7911 name, elements._truncated_label
7912 ):
7913 name = name.apply_map(anon_map)
7914
7915 if not getattr(column, "is_literal", False):
7916 if use_table:
7917 return (
7918 self.format_table(
7919 column.table, use_schema=use_schema, name=table_name
7920 )
7921 + "."
7922 + self.quote(name)
7923 )
7924 else:
7925 return self.quote(name)
7926 else:
7927 # literal textual elements get stuck into ColumnClause a lot,
7928 # which shouldn't get quoted
7929
7930 if use_table:
7931 return (
7932 self.format_table(
7933 column.table, use_schema=use_schema, name=table_name
7934 )
7935 + "."
7936 + name
7937 )
7938 else:
7939 return name
7940
7941 def format_table_seq(self, table, use_schema=True):
7942 """Format table name and schema as a tuple."""
7943
7944 # Dialects with more levels in their fully qualified references
7945 # ('database', 'owner', etc.) could override this and return
7946 # a longer sequence.
7947
7948 effective_schema = self.schema_for_object(table)
7949
7950 if not self.omit_schema and use_schema and effective_schema:
7951 return (
7952 self.quote_schema(effective_schema),
7953 self.format_table(table, use_schema=False),
7954 )
7955 else:
7956 return (self.format_table(table, use_schema=False),)
7957
    @util.memoized_property
    def _r_identifiers(self):
        # regex splitting a dotted identifier chain into components:
        # each component is either a quoted identifier (allowing the
        # escaped closing-quote sequence inside) or a run of non-dot
        # characters, and must be followed by a dot or end of string
        initial, final, escaped_final = (
            re.escape(s)
            for s in (
                self.initial_quote,
                self.final_quote,
                self._escape_identifier(self.final_quote),
            )
        )
        r = re.compile(
            r"(?:"
            r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s"
            r"|([^\.]+))(?=\.|$))+"
            % {"initial": initial, "final": final, "escaped": escaped_final}
        )
        return r
7975
7976 def unformat_identifiers(self, identifiers: str) -> Sequence[str]:
7977 """Unpack 'schema.table.column'-like strings into components."""
7978
7979 r = self._r_identifiers
7980 return [
7981 self._unescape_identifier(i)
7982 for i in [a or b for a, b in r.findall(identifiers)]
7983 ]