1# sql/compiler.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Base SQL and DDL compiler implementations.
10
11Classes provided include:
12
13:class:`.compiler.SQLCompiler` - renders SQL
14strings
15
16:class:`.compiler.DDLCompiler` - renders DDL
17(data definition language) strings
18
19:class:`.compiler.GenericTypeCompiler` - renders
20type specification strings.
21
22To generate user-defined SQL strings, see
23:doc:`/ext/compiler`.
24
25"""
26from __future__ import annotations
27
28import collections
29import collections.abc as collections_abc
30import contextlib
31from enum import IntEnum
32import functools
33import itertools
34import operator
35import re
36from time import perf_counter
37import typing
38from typing import Any
39from typing import Callable
40from typing import cast
41from typing import ClassVar
42from typing import Dict
43from typing import FrozenSet
44from typing import Iterable
45from typing import Iterator
46from typing import List
47from typing import Literal
48from typing import Mapping
49from typing import MutableMapping
50from typing import NamedTuple
51from typing import NoReturn
52from typing import Optional
53from typing import Pattern
54from typing import Protocol
55from typing import Sequence
56from typing import Set
57from typing import Tuple
58from typing import Type
59from typing import TYPE_CHECKING
60from typing import TypedDict
61from typing import Union
62
63from . import base
64from . import coercions
65from . import crud
66from . import elements
67from . import functions
68from . import operators
69from . import roles
70from . import schema
71from . import selectable
72from . import sqltypes
73from . import util as sql_util
74from ._typing import is_column_element
75from ._typing import is_dml
76from .base import _de_clone
77from .base import _from_objects
78from .base import _NONE_NAME
79from .base import _SentinelDefaultCharacterization
80from .base import NO_ARG
81from .elements import quoted_name
82from .sqltypes import TupleType
83from .visitors import prefix_anon_map
84from .. import exc
85from .. import util
86from ..util import FastIntFlag
87from ..util.typing import Self
88from ..util.typing import TupleAny
89from ..util.typing import Unpack
90
91if typing.TYPE_CHECKING:
92 from .annotation import _AnnotationDict
93 from .base import _AmbiguousTableNameMap
94 from .base import CompileState
95 from .base import Executable
96 from .cache_key import CacheKey
97 from .ddl import CreateTableAs
98 from .ddl import ExecutableDDLElement
99 from .dml import Delete
100 from .dml import Insert
101 from .dml import Update
102 from .dml import UpdateBase
103 from .dml import UpdateDMLState
104 from .dml import ValuesBase
105 from .elements import _truncated_label
106 from .elements import BinaryExpression
107 from .elements import BindParameter
108 from .elements import ClauseElement
109 from .elements import ColumnClause
110 from .elements import ColumnElement
111 from .elements import False_
112 from .elements import Label
113 from .elements import Null
114 from .elements import True_
115 from .functions import Function
116 from .schema import Column
117 from .schema import Constraint
118 from .schema import ForeignKeyConstraint
119 from .schema import Index
120 from .schema import PrimaryKeyConstraint
121 from .schema import Table
122 from .schema import UniqueConstraint
123 from .selectable import _ColumnsClauseElement
124 from .selectable import AliasedReturnsRows
125 from .selectable import CompoundSelectState
126 from .selectable import CTE
127 from .selectable import FromClause
128 from .selectable import NamedFromClause
129 from .selectable import ReturnsRows
130 from .selectable import Select
131 from .selectable import SelectState
132 from .type_api import _BindProcessorType
133 from .type_api import TypeDecorator
134 from .type_api import TypeEngine
135 from .type_api import UserDefinedType
136 from .visitors import Visitable
137 from ..engine.cursor import CursorResultMetaData
138 from ..engine.interfaces import _CoreSingleExecuteParams
139 from ..engine.interfaces import _DBAPIAnyExecuteParams
140 from ..engine.interfaces import _DBAPIMultiExecuteParams
141 from ..engine.interfaces import _DBAPISingleExecuteParams
142 from ..engine.interfaces import _ExecuteOptions
143 from ..engine.interfaces import _GenericSetInputSizesType
144 from ..engine.interfaces import _MutableCoreSingleExecuteParams
145 from ..engine.interfaces import Dialect
146 from ..engine.interfaces import SchemaTranslateMapType
147
148
# Type alias: a dictionary keyed by FromClause objects with string values.
# NOTE(review): by its name this carries per-FROM hint text; the code that
# fills and reads it is outside this view.
_FromHintsType = Dict["FromClause", str]
150
# Default set of reserved words, stored in lower case.
# NOTE(review): presumably consulted when deciding whether an identifier
# needs quoting; the consumer is outside this view, so confirm there whether
# lookups lower-case the candidate name first.
RESERVED_WORDS = {
    "all",
    "analyse",
    "analyze",
    "and",
    "any",
    "array",
    "as",
    "asc",
    "asymmetric",
    "authorization",
    "between",
    "binary",
    "both",
    "case",
    "cast",
    "check",
    "collate",
    "column",
    "constraint",
    "create",
    "cross",
    "current_date",
    "current_role",
    "current_time",
    "current_timestamp",
    "current_user",
    "default",
    "deferrable",
    "desc",
    "distinct",
    "do",
    "else",
    "end",
    "except",
    "false",
    "for",
    "foreign",
    "freeze",
    "from",
    "full",
    "grant",
    "group",
    "having",
    "ilike",
    "in",
    "initially",
    "inner",
    "intersect",
    "into",
    "is",
    "isnull",
    "join",
    "leading",
    "left",
    "like",
    "limit",
    "localtime",
    "localtimestamp",
    "natural",
    "new",
    "not",
    "notnull",
    "null",
    "off",
    "offset",
    "old",
    "on",
    "only",
    "or",
    "order",
    "outer",
    "overlaps",
    "placing",
    "primary",
    "references",
    "right",
    "select",
    "session_user",
    "set",
    "similar",
    "some",
    "symmetric",
    "table",
    "then",
    "to",
    "trailing",
    "true",
    "union",
    "unique",
    "user",
    "using",
    "verbose",
    "when",
    "where",
}
247
# Whole-string patterns (case-insensitive) accepting only ASCII letters,
# digits, underscore and "$"; the second form additionally accepts spaces.
LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", flags=re.IGNORECASE)
LEGAL_CHARACTERS_PLUS_SPACE = re.compile(
    r"^[A-Z0-9_ $]+$", flags=re.IGNORECASE
)

# The ten decimal digits plus "$".
ILLEGAL_INITIAL_CHARACTERS = set("0123456789") | {"$"}

# Whole-string, case-insensitive patterns for the keywords accepted in
# foreign key ON DELETE / ON UPDATE / INITIALLY positions.
FK_ON_DELETE = re.compile(
    r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$",
    flags=re.IGNORECASE,
)
FK_ON_UPDATE = re.compile(
    r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$",
    flags=re.IGNORECASE,
)
FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", flags=re.IGNORECASE)

# BIND_PARAMS captures the name in ":name" tokens, skipping any colon that is
# preceded by a word character, "$", another colon or a backslash (\x5c), or
# whose name is followed by a colon.  BIND_PARAMS_ESC matches the
# backslash-escaped form and captures the text after the backslash.
BIND_PARAMS = re.compile(
    r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", flags=re.UNICODE
)
BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", flags=re.UNICODE)
261
# Placeholder template for each supported paramstyle name.  Entries that
# contain "%(name)s" are written for %-interpolation with a ``name`` key
# ("%%" is a literal percent sign after interpolation); the "numeric" styles
# carry a literal "[_POSITION]" marker instead.
_pyformat_template = "%%(%(name)s)s"
BIND_TEMPLATES = dict(
    pyformat=_pyformat_template,
    qmark="?",
    format="%%s",
    numeric=":[_POSITION]",
    numeric_dollar="$[_POSITION]",
    named=":%(name)s",
)
271
272
# Lookup from operator callables in the ``operators`` module to the SQL text
# used for them.  The spacing is part of each value: infix operators carry a
# space on both sides, prefix keywords a trailing space, and postfix
# modifiers a leading space, so the strings can be concatenated directly
# with their operands.
OPERATORS = {
    # binary
    operators.and_: " AND ",
    operators.or_: " OR ",
    operators.add: " + ",
    operators.mul: " * ",
    operators.sub: " - ",
    operators.mod: " % ",
    # NOTE(review): neg is a prefix operator (no surrounding spaces) even
    # though it sits in the "binary" group.
    operators.neg: "-",
    operators.lt: " < ",
    operators.le: " <= ",
    operators.ne: " != ",
    operators.gt: " > ",
    operators.ge: " >= ",
    operators.eq: " = ",
    operators.is_distinct_from: " IS DISTINCT FROM ",
    operators.is_not_distinct_from: " IS NOT DISTINCT FROM ",
    operators.concat_op: " || ",
    operators.match_op: " MATCH ",
    operators.not_match_op: " NOT MATCH ",
    operators.in_op: " IN ",
    operators.not_in_op: " NOT IN ",
    operators.comma_op: ", ",
    operators.from_: " FROM ",
    operators.as_: " AS ",
    operators.is_: " IS ",
    operators.is_not: " IS NOT ",
    operators.collate: " COLLATE ",
    # unary
    operators.exists: "EXISTS ",
    operators.distinct_op: "DISTINCT ",
    operators.inv: "NOT ",
    operators.any_op: "ANY ",
    operators.all_op: "ALL ",
    # modifiers
    operators.desc_op: " DESC",
    operators.asc_op: " ASC",
    operators.nulls_first_op: " NULLS FIRST",
    operators.nulls_last_op: " NULLS LAST",
    # bitwise
    operators.bitwise_xor_op: " ^ ",
    operators.bitwise_or_op: " | ",
    operators.bitwise_and_op: " & ",
    operators.bitwise_not_op: "~",
    operators.bitwise_lshift_op: " << ",
    operators.bitwise_rshift_op: " >> ",
}
320
# Lookup from function classes in the ``functions`` module to a fixed name
# string.  Note the mixed casing of the values: some are lower case
# ("coalesce", "random", "sysdate") while the rest are upper case keywords.
# NOTE(review): presumably the string is what gets rendered for the
# function; the consuming code is outside this view.
FUNCTIONS: Dict[Type[Function[Any]], str] = {
    functions.coalesce: "coalesce",
    functions.current_date: "CURRENT_DATE",
    functions.current_time: "CURRENT_TIME",
    functions.current_timestamp: "CURRENT_TIMESTAMP",
    functions.current_user: "CURRENT_USER",
    functions.localtime: "LOCALTIME",
    functions.localtimestamp: "LOCALTIMESTAMP",
    functions.random: "random",
    functions.sysdate: "sysdate",
    functions.session_user: "SESSION_USER",
    functions.user: "USER",
    functions.cube: "CUBE",
    functions.rollup: "ROLLUP",
    functions.grouping_sets: "GROUPING SETS",
}
337
338
# Default mapping of EXTRACT field names; every field maps to itself here.
EXTRACT_MAP = {
    field: field
    for field in (
        "month",
        "day",
        "year",
        "second",
        "hour",
        "doy",
        "minute",
        "quarter",
        "dow",
        "week",
        "epoch",
        "milliseconds",
        "microseconds",
        "timezone_hour",
        "timezone_minute",
    )
}
356
# Lookup from each ``selectable._CompoundSelectKeyword`` member to the SQL
# keyword text for that set operation.
COMPOUND_KEYWORDS = {
    selectable._CompoundSelectKeyword.UNION: "UNION",
    selectable._CompoundSelectKeyword.UNION_ALL: "UNION ALL",
    selectable._CompoundSelectKeyword.EXCEPT: "EXCEPT",
    selectable._CompoundSelectKeyword.EXCEPT_ALL: "EXCEPT ALL",
    selectable._CompoundSelectKeyword.INTERSECT: "INTERSECT",
    selectable._CompoundSelectKeyword.INTERSECT_ALL: "INTERSECT ALL",
}
365
366
class ResultColumnsEntry(NamedTuple):
    """One column expression that the compiled statement expects to find
    in its result rows.

    Entries usually come from the columns clause of a SELECT statement,
    but a RETURNING clause or a dialect-specific emulation may produce
    them as well.

    """

    keyname: str
    """the string name expected to appear in cursor.description"""

    name: str
    """the column's name, which may be a label"""

    objects: Tuple[Any, ...]
    """objects that should be able to locate this column in a RowMapping;
    typically string names and aliases together with Column objects.

    """

    type: TypeEngine[Any]
    """the datatype tied to this column.  This is the direct link between
    the compiled statement and the "result processing" logic applied to
    rows that come back from the cursor.

    """
396
397
class _ResultMapAppender(Protocol):
    """Structural type for a callable that accepts the pieces of one result
    column (``keyname``, ``name``, ``objects``, ``type_``) and returns
    nothing.

    The parameters line up with the fields of :class:`.ResultColumnsEntry`,
    except that ``objects`` may be any sequence and the type is passed as
    ``type_``.

    """

    def __call__(
        self,
        keyname: str,
        name: str,
        objects: Sequence[Any],
        type_: TypeEngine[Any],
    ) -> None: ...
406
407
# integer indexes into ResultColumnsEntry used by cursor.py.
# some profiling showed integer access faster than named tuple
#
# Positions follow the ResultColumnsEntry field order:
#   RM_RENDERED_NAME -> keyname, RM_NAME -> name,
#   RM_OBJECTS -> objects, RM_TYPE -> type
RM_RENDERED_NAME: Literal[0] = 0
RM_NAME: Literal[1] = 1
RM_OBJECTS: Literal[2] = 2
RM_TYPE: Literal[3] = 3
414
415
class _BaseCompilerStackEntry(TypedDict):
    """Keys that every compiler stack entry must carry.

    NOTE(review): how the two FROM sets are populated and consumed is not
    visible in this part of the file.

    """

    asfrom_froms: Set[FromClause]
    correlate_froms: Set[FromClause]
    selectable: ReturnsRows
420
421
class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
    """Full compiler stack entry.

    Adds keys to :class:`._BaseCompilerStackEntry`; because of
    ``total=False`` every key declared here is optional, while the inherited
    keys stay required.

    """

    compile_state: CompileState
    need_result_map_for_nested: bool
    need_result_map_for_compound: bool
    select_0: ReturnsRows
    insert_from_select: Select[Unpack[TupleAny]]
428
429
class ExpandedState(NamedTuple):
    """State used when producing "expanded" and "post compile" bound
    parameters for a statement.

    "expanded" parameters are generated at statement execution time to
    match the number of values passed; the most prominent example is the
    set of individual elements inside of an IN expression.

    "post compile" parameters have their SQL literal value rendered into
    the SQL statement at execution time, instead of being passed to the
    driver as separate parameters.

    An :class:`.ExpandedState` is obtained by calling the
    :meth:`.SQLCompiler.construct_expanded_state` method on any
    :class:`.SQLCompiler` instance.

    """

    statement: str
    """String SQL statement, with parameters fully expanded"""

    parameters: _CoreSingleExecuteParams
    """Parameter dictionary, with parameters fully expanded.

    When the statement uses named parameters, the keys of this dictionary
    match the names in the statement exactly.  When it uses positional
    parameters, :attr:`.ExpandedState.positional_parameters` yields the
    positional parameter set as a tuple.

    """

    processors: Mapping[str, _BindProcessorType[Any]]
    """mapping of bound value processors"""

    positiontup: Optional[Sequence[str]]
    """Sequence of string names giving the order of positional
    parameters"""

    parameter_expansion: Mapping[str, List[str]]
    """Mapping that links each original parameter name to the list of
    "expanded" parameter names produced for it, for those parameters that
    were expanded."""

    @property
    def positional_parameters(self) -> Tuple[Any, ...]:
        """Tuple of positional parameters, for statements that were compiled
        using a positional paramstyle.

        Raises :class:`.InvalidRequestError` when no positional ordering is
        present.

        """
        ordering = self.positiontup
        if ordering is None:
            raise exc.InvalidRequestError(
                "statement does not use a positional paramstyle"
            )

        values = self.parameters
        return tuple(values[key] for key in ordering)

    @property
    def additional_parameters(self) -> _CoreSingleExecuteParams:
        """synonym for :attr:`.ExpandedState.parameters`."""
        return self.parameters
489
490
class _InsertManyValues(NamedTuple):
    """represents state to use for executing an "insertmanyvalues" statement.

    The primary consumers of this object are the
    :meth:`.SQLCompiler._deliver_insertmanyvalues_batches` and
    :meth:`.DefaultDialect._deliver_insertmanyvalues_batches` methods.

    .. versionadded:: 2.0

    """

    is_default_expr: bool
    """if True, the statement is of the form
    ``INSERT INTO TABLE DEFAULT VALUES``, and can't be rewritten as a "batch"

    """

    single_values_expr: str
    """The rendered "values" clause of the INSERT statement.

    This is typically the parenthesized section e.g. "(?, ?, ?)" or similar.
    The insertmanyvalues logic uses this string as a search and replace
    target.

    """

    insert_crud_params: List[crud._CrudParamElementStr]
    """List of Column / bind names etc. used while rewriting the statement"""

    num_positional_params_counted: int
    """the number of bound parameters in a single-row statement.

    This count may be larger or smaller than the actual number of columns
    targeted in the INSERT, as it accommodates for SQL expressions
    in the values list that may have zero or more parameters embedded
    within them.

    This count is part of what's used to organize rewritten parameter lists
    when batching.

    """

    sort_by_parameter_order: bool = False
    """if the ``sort_by_parameter_order`` parameter was used on the
    insert.

    All of the attributes following this will only be used if this is True.

    """

    includes_upsert_behaviors: bool = False
    """if True, we have to accommodate for upsert behaviors.

    This will in some cases downgrade "insertmanyvalues" that requests
    deterministic ordering.

    """

    sentinel_columns: Optional[Sequence[Column[Any]]] = None
    """List of sentinel columns that were located.

    This list is only here if the INSERT asked for
    sort_by_parameter_order=True,
    and dialect-appropriate sentinel columns were located.

    .. versionadded:: 2.0.10

    """

    num_sentinel_columns: int = 0
    """how many sentinel columns are in the above list, if any.

    This is the same as
    ``len(sentinel_columns) if sentinel_columns is not None else 0``

    """

    sentinel_param_keys: Optional[Sequence[str]] = None
    """parameter str keys in each param dictionary / tuple
    that would link to the client side "sentinel" values for that row, which
    we can use to match up parameter sets to result rows.

    This is only present if sentinel_columns is present and the INSERT
    statement actually refers to client side values for these sentinel
    columns.

    .. versionadded:: 2.0.10

    .. versionchanged:: 2.0.29 - the sequence is now string dictionary keys
       only, used against the "compiled parameters" collection before
       the parameters were converted by bound parameter processors

    """

    implicit_sentinel: bool = False
    """if True, we have exactly one sentinel column and it uses a server side
    value, currently has to generate an incrementing integer value.

    The dialect in question would have asserted that it supports receiving
    these values back and sorting on that value as a means of guaranteeing
    correlation with the incoming parameter list.

    .. versionadded:: 2.0.10

    """

    embed_values_counter: bool = False
    """Whether to embed an incrementing integer counter in each parameter
    set within the VALUES clause as parameters are batched over.

    This is only used for a specific INSERT..SELECT..VALUES..RETURNING syntax
    where a subquery is used to produce value tuples.  Current support
    includes PostgreSQL, Microsoft SQL Server.

    .. versionadded:: 2.0.10

    """
608
609
class _InsertManyValuesBatch(NamedTuple):
    """represents an individual batch SQL statement for insertmanyvalues.

    This is passed through the
    :meth:`.SQLCompiler._deliver_insertmanyvalues_batches` and
    :meth:`.DefaultDialect._deliver_insertmanyvalues_batches` methods out
    to the :class:`.Connection` within the
    :meth:`.Connection._exec_insertmany_context` method.

    .. versionadded:: 2.0.10

    """

    # NOTE(review): the per-field notes below are read off the field names
    # and annotations only; the code that builds these tuples is outside
    # this view, so confirm against it before relying on them.

    # statement text and DBAPI parameters for this batch (presumably after
    # the VALUES clause has been rewritten for the batch size)
    replaced_statement: str
    replaced_parameters: _DBAPIAnyExecuteParams
    # setinputsizes data for this batch, or None
    processed_setinputsizes: Optional[_GenericSetInputSizesType]
    # the individual parameter sets making up this batch
    batch: Sequence[_DBAPISingleExecuteParams]
    # one tuple of sentinel values per parameter set -- TODO confirm
    sentinel_values: Sequence[Tuple[Any, ...]]
    # size of this batch and its position among all batches; whether
    # batchnum counts from 0 or 1 is not visible here
    current_batch_size: int
    batchnum: int
    total_batches: int
    rows_sorted: bool
    is_downgraded: bool
633
634
class InsertmanyvaluesSentinelOpts(FastIntFlag):
    """bitflag enum indicating styles of PK defaults
    which can work as implicit sentinel columns

    """

    # single-bit members
    NOT_SUPPORTED = 1
    AUTOINCREMENT = 2
    IDENTITY = 4
    SEQUENCE = 8

    # combined masks: ANY_AUTOINCREMENT is 2 | 4 | 8 == 14 and
    # _SUPPORTED_OR_NOT is 1 | 14 == 15, i.e. all four bits above
    ANY_AUTOINCREMENT = AUTOINCREMENT | IDENTITY | SEQUENCE
    _SUPPORTED_OR_NOT = NOT_SUPPORTED | ANY_AUTOINCREMENT

    # further single-bit options; note that bit value 32 is not assigned
    # in this class
    USE_INSERT_FROM_SELECT = 16
    RENDER_SELECT_COL_CASTS = 64
651
652
class AggregateOrderByStyle(IntEnum):
    """Enumerates how a backend database supports ORDER BY for aggregate
    functions.

    .. versionadded:: 2.1

    """

    NONE = 0
    """the database has no ORDER BY for aggregate functions"""

    INLINE = 1
    """ORDER BY is rendered within the argument list of the function,
    typically as its last element"""

    WITHIN_GROUP = 2
    """all aggregate functions (not only the ordered set ones) use the
    WITHIN GROUP (ORDER BY ...) phrase"""
671
672
class CompilerState(IntEnum):
    """Phases a compiler object can be in with respect to its statement."""

    COMPILING = 0
    """a statement is present and the compilation phase is in progress"""

    STRING_APPLIED = 1
    """a statement is present and its string form has been applied.

    Subclasses may still have additional processors pending.

    """

    NO_STATEMENT = 2
    """the compiler has no statement to compile and is used for
    method access"""
687
688
class Linting(IntEnum):
    """Preferences for the 'SQL linting' feature.

    At present the feature covers flagging cartesian products in SQL
    statements.

    """

    NO_LINTING = 0
    """Disable all linting."""

    COLLECT_CARTESIAN_PRODUCTS = 1
    """Collect data on FROMs and cartesian products and gather into
    'self.from_linter'"""

    WARN_LINTING = 2
    """Emit warnings for linters that find problems"""

    FROM_LINTING = COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING
    """Warn for cartesian products; combines COLLECT_CARTESIAN_PRODUCTS
    and WARN_LINTING"""


# module-level shortcuts for the four members above
NO_LINTING = Linting.NO_LINTING
COLLECT_CARTESIAN_PRODUCTS = Linting.COLLECT_CARTESIAN_PRODUCTS
WARN_LINTING = Linting.WARN_LINTING
FROM_LINTING = Linting.FROM_LINTING
715
716
class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
    """Holds the current state for the "cartesian product" detection
    feature: the collected FROM elements and the edges joining them."""

    def lint(self, start=None):
        """Look for FROM elements that cannot be reached through the edges.

        :param start: optional element to begin the traversal from; it must
         be present in ``froms``.  When omitted, an arbitrary element is
         used.

        :return: a tuple ``(unreached, origin)`` where ``unreached`` is the
         set of elements never reached from ``origin``, or ``(None, None)``
         when there are no FROMs or all of them are connected.

        """
        if not self.froms:
            return None, None

        pending_edges = set(self.edges)
        unreached = set(self.froms)

        if start is None:
            origin = unreached.pop()
        else:
            origin = start
            unreached.remove(origin)

        frontier = collections.deque([origin])

        while frontier and unreached:
            current = frontier.popleft()
            unreached.discard(current)

            # nodes within edges are compared by hash equality, because
            # "annotated" elements match their non-annotated counterparts.
            # native containment routines ("current in edge",
            # "edge.index(current)") avoid in-python hash() calls.
            touching = {edge for edge in pending_edges if current in edge}

            # for every matching edge, push its *other* endpoint onto the
            # left side of the deque
            frontier.extendleft(
                edge[not edge.index(current)] for edge in touching
            )
            pending_edges -= touching

        # anything still unreached forms a cartesian product with origin
        if unreached:
            return unreached, origin
        return None, None

    def warn(self, stmt_type="SELECT"):
        """Emit a warning if :meth:`.lint` finds disconnected FROM
        elements.

        :param stmt_type: statement name to use in the warning text.

        """
        the_rest, start_with = self.lint()

        if not the_rest:
            return

        template = (
            "{stmt_type} statement has a cartesian product between "
            "FROM element(s) {froms} and "
            'FROM element "{start}". Apply join condition(s) '
            "between each element to resolve."
        )
        froms_str = ", ".join(
            f'"{self.froms[from_]}"' for from_ in the_rest
        )
        util.warn(
            template.format(
                stmt_type=stmt_type,
                froms=froms_str,
                start=self.froms[start_with],
            )
        )
781
782
class Compiled:
    """Represent a compiled SQL or DDL expression.

    The ``__str__`` method of the ``Compiled`` object should produce
    the actual text of the statement.  ``Compiled`` objects are
    specific to their underlying database dialect, and also may
    or may not be specific to the columns referenced within a
    particular set of bind parameters.  In no case should the
    ``Compiled`` object be dependent on the actual values of those
    bind parameters, even though it may reference those values as
    defaults.
    """

    statement: Optional[ClauseElement] = None
    "The statement to compile."
    string: str = ""
    "The string representation of the ``statement``"

    state: CompilerState
    """description of the compiler's state"""

    is_sql = False
    is_ddl = False

    _cached_metadata: Optional[CursorResultMetaData] = None

    _result_columns: Optional[List[ResultColumnsEntry]] = None

    schema_translate_map: Optional[SchemaTranslateMapType] = None

    execution_options: _ExecuteOptions = util.EMPTY_DICT
    """
    Execution options propagated from the statement.   In some cases,
    sub-elements of the statement can modify these.
    """

    preparer: IdentifierPreparer

    _annotations: _AnnotationDict = util.EMPTY_DICT

    compile_state: Optional[CompileState] = None
    """Optional :class:`.CompileState` object that maintains additional
    state used by the compiler.

    Major executable objects such as :class:`_expression.Insert`,
    :class:`_expression.Update`, :class:`_expression.Delete`,
    :class:`_expression.Select` will generate this
    state when compiled in order to calculate additional information about the
    object.   For the top level object that is to be executed, the state can be
    stored here where it can also have applicability towards result set
    processing.

    .. versionadded:: 1.4

    """

    dml_compile_state: Optional[CompileState] = None
    """Optional :class:`.CompileState` assigned at the same point that
    .isinsert, .isupdate, or .isdelete is assigned.

    This will normally be the same object as .compile_state, with the
    exception of cases like the :class:`.ORMFromStatementCompileState`
    object.

    .. versionadded:: 1.4.40

    """

    cache_key: Optional[CacheKey] = None
    """The :class:`.CacheKey` that was generated ahead of creating this
    :class:`.Compiled` object.

    This is used for routines that need access to the original
    :class:`.CacheKey` instance generated when the :class:`.Compiled`
    instance was first cached, typically in order to reconcile
    the original list of :class:`.BindParameter` objects with a
    per-statement list that's generated on each call.

    """

    _gen_time: float
    """Generation time of this :class:`.Compiled`, used for reporting
    cache stats."""

    def __init__(
        self,
        dialect: Dialect,
        statement: Optional[ClauseElement],
        schema_translate_map: Optional[SchemaTranslateMapType] = None,
        render_schema_translate: bool = False,
        compile_kwargs: Mapping[str, Any] = util.immutabledict(),
    ):
        """Construct a new :class:`.Compiled` object.

        :param dialect: :class:`.Dialect` to compile against.

        :param statement: :class:`_expression.ClauseElement` to be compiled.
         When ``None``, nothing is compiled and the object is left in the
         :attr:`.CompilerState.NO_STATEMENT` state.

        :param schema_translate_map: dictionary of schema names to be
         translated when forming the resultant SQL

         .. seealso::

            :ref:`schema_translating`

        :param render_schema_translate: when True, the compiled string is
         additionally passed through the preparer's
         ``_render_schema_translates()`` along with the
         ``schema_translate_map``, which must not be ``None`` in that case.

        :param compile_kwargs: additional kwargs that will be
         passed to the initial call to :meth:`.Compiled.process`.


        """
        self.dialect = dialect
        self.preparer = self.dialect.identifier_preparer
        # an empty map is falsy and is therefore treated the same as None;
        # the class-level defaults stay in effect in that case
        if schema_translate_map:
            self.schema_translate_map = schema_translate_map
            self.preparer = self.preparer._with_schema_translate(
                schema_translate_map
            )

        if statement is not None:
            # the state is COMPILING for the whole duration of process()
            # and only moves to STRING_APPLIED once self.string is final
            self.state = CompilerState.COMPILING
            self.statement = statement
            # note: can_execute is only assigned on this branch
            self.can_execute = statement.supports_execution
            self._annotations = statement._annotations
            if self.can_execute:
                if TYPE_CHECKING:
                    assert isinstance(statement, Executable)
                self.execution_options = statement._execution_options
            self.string = self.process(self.statement, **compile_kwargs)

            if render_schema_translate:
                assert schema_translate_map is not None
                self.string = self.preparer._render_schema_translates(
                    self.string, schema_translate_map
                )

            self.state = CompilerState.STRING_APPLIED
        else:
            self.state = CompilerState.NO_STATEMENT

        # timestamp taken after compilation has finished
        self._gen_time = perf_counter()

    def __init_subclass__(cls) -> None:
        """Run the :meth:`._init_compiler_cls` hook for each new subclass."""
        cls._init_compiler_cls()
        return super().__init_subclass__()

    @classmethod
    def _init_compiler_cls(cls):
        """Class-level setup hook called from ``__init_subclass__``.

        Does nothing in this base class.

        """
        pass

    def visit_unsupported_compilation(self, element, err, **kw):
        """Raise :class:`.UnsupportedCompilationError` for the type of
        ``element``, chained from ``err``."""
        raise exc.UnsupportedCompilationError(self, type(element)) from err

    @property
    def sql_compiler(self) -> SQLCompiler:
        """Return a Compiled that is capable of processing SQL expressions.

        If this compiler is one, it would likely just return 'self'.

        """

        raise NotImplementedError()

    def process(self, obj: Visitable, **kwargs: Any) -> str:
        """Dispatch ``obj`` to this compiler through its
        ``_compiler_dispatch()`` method and return the result."""
        return obj._compiler_dispatch(self, **kwargs)

    def __str__(self) -> str:
        """Return the string text of the generated SQL or DDL."""

        # an empty string is returned while compilation is still in
        # progress or when there is no statement
        if self.state is CompilerState.STRING_APPLIED:
            return self.string
        else:
            return ""

    def construct_params(
        self,
        params: Optional[_CoreSingleExecuteParams] = None,
        extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
        escape_names: bool = True,
    ) -> Optional[_MutableCoreSingleExecuteParams]:
        """Return the bind params for this compiled object.

        :param params: a dict of string/object pairs whose values will
                       override bind values compiled in to the
                       statement.
        """

        raise NotImplementedError()

    @property
    def params(self):
        """Return the bind params for this compiled object."""
        return self.construct_params()
975
976
class TypeCompiler(util.EnsureKWArg):
    """Renders the DDL specification for TypeEngine objects."""

    ensure_kwarg = r"visit_\w+"

    def __init__(self, dialect: Dialect):
        self.dialect = dialect

    def process(self, type_: TypeEngine[Any], **kw: Any) -> str:
        """Return the string for ``type_`` as rendered by this compiler.

        If the type has a variant registered under this dialect's name,
        that variant is compiled in place of the type itself.

        """
        variants = type_._variant_mapping
        if variants and self.dialect.name in variants:
            type_ = variants[self.dialect.name]
        return type_._compiler_dispatch(self, **kw)

    def visit_unsupported_compilation(
        self, element: Any, err: Exception, **kw: Any
    ) -> NoReturn:
        """Raise :class:`.UnsupportedCompilationError` for ``element``,
        chained from ``err``."""
        failure = exc.UnsupportedCompilationError(self, element)
        raise failure from err
997
998
999# this was a Visitable, but to allow accurate detection of
1000# column elements this is actually a column element
class _CompileLabel(
    roles.BinaryElementRole[Any], elements.CompilerColumnElement
):
    """lightweight label object which acts as an expression.Label."""

    # shares the "label" visit name, so it is dispatched the same way as a
    # regular label
    __visit_name__ = "label"
    __slots__ = "element", "name", "_alt_names"

    def __init__(self, col, name, alt_names=()):
        """Wrap ``col`` under the label ``name``.

        ``alt_names`` must be a tuple; the wrapped column itself is always
        placed first in ``_alt_names``.

        """
        self.element = col
        self.name = name
        self._alt_names = (col,) + alt_names

    @property
    def proxy_set(self):
        """The ``proxy_set`` of the wrapped element."""
        return self.element.proxy_set

    @property
    def type(self):
        """The type of the wrapped element."""
        return self.element.type

    def self_group(self, **kw):
        """Return this object unchanged; no grouping is applied."""
        return self
1024
1025
class aggregate_orderby_inline(
    roles.BinaryElementRole[Any], elements.CompilerColumnElement
):
    """produce ORDER BY inside of function argument lists"""

    __visit_name__ = "aggregate_orderby_inline"
    __slots__ = "element", "aggregate_order_by"

    def __init__(self, element, orderby):
        """Pair ``element`` with the ``orderby`` to be rendered with it."""
        self.element = element
        self.aggregate_order_by = orderby

    def __iter__(self):
        """Iterate over the wrapped element."""
        return iter(self.element)

    @property
    def proxy_set(self):
        """The ``proxy_set`` of the wrapped element."""
        return self.element.proxy_set

    @property
    def type(self):
        """The type of the wrapped element."""
        return self.element.type

    def self_group(self, **kw):
        """Return this object unchanged; no grouping is applied."""
        return self

    def _with_binary_element_type(self, type_):
        """Return a new wrapper whose element has been re-typed with
        ``type_``, keeping the same ORDER BY."""
        return aggregate_orderby_inline(
            self.element._with_binary_element_type(type_),
            self.aggregate_order_by,
        )
1057
1058
class ilike_case_insensitive(
    roles.BinaryElementRole[Any], elements.CompilerColumnElement
):
    """produce a wrapping element for a case-insensitive portion of
    an ILIKE construct.

    The construct usually renders the ``lower()`` function, but on
    PostgreSQL will pass silently with the assumption that "ILIKE"
    is being used.

    .. versionadded:: 2.0

    """

    __visit_name__ = "ilike_case_insensitive_operand"
    __slots__ = "element", "comparator"

    def __init__(self, element):
        """Wrap ``element``, reusing its ``comparator``."""
        self.element = element
        self.comparator = element.comparator

    @property
    def proxy_set(self):
        """The ``proxy_set`` of the wrapped element."""
        return self.element.proxy_set

    @property
    def type(self):
        """The type of the wrapped element."""
        return self.element.type

    def self_group(self, **kw):
        """Return this object unchanged; no grouping is applied."""
        return self

    def _with_binary_element_type(self, type_):
        """Return a new wrapper whose element has been re-typed with
        ``type_``."""
        return ilike_case_insensitive(
            self.element._with_binary_element_type(type_)
        )
1095
1096
1097class SQLCompiler(Compiled):
1098 """Default implementation of :class:`.Compiled`.
1099
1100 Compiles :class:`_expression.ClauseElement` objects into SQL strings.
1101
1102 """
1103
1104 extract_map = EXTRACT_MAP
1105
1106 bindname_escape_characters: ClassVar[Mapping[str, str]] = (
1107 util.immutabledict(
1108 {
1109 "%": "P",
1110 "(": "A",
1111 ")": "Z",
1112 ":": "C",
1113 ".": "_",
1114 "[": "_",
1115 "]": "_",
1116 " ": "_",
1117 }
1118 )
1119 )
1120 """A mapping (e.g. dict or similar) containing a lookup of
1121 characters keyed to replacement characters which will be applied to all
1122 'bind names' used in SQL statements as a form of 'escaping'; the given
1123 characters are replaced entirely with the 'replacement' character when
1124 rendered in the SQL statement, and a similar translation is performed
1125 on the incoming names used in parameter dictionaries passed to methods
1126 like :meth:`_engine.Connection.execute`.
1127
1128 This allows bound parameter names used in :func:`_sql.bindparam` and
1129 other constructs to have any arbitrary characters present without any
1130 concern for characters that aren't allowed at all on the target database.
1131
1132 Third party dialects can establish their own dictionary here to replace the
1133 default mapping, which will ensure that the particular characters in the
1134 mapping will never appear in a bound parameter name.
1135
1136 The dictionary is evaluated at **class creation time**, so cannot be
1137 modified at runtime; it must be present on the class when the class
1138 is first declared.
1139
1140 Note that for dialects that have additional bound parameter rules such
1141 as additional restrictions on leading characters, the
1142 :meth:`_sql.SQLCompiler.bindparam_string` method may need to be augmented.
1143 See the cx_Oracle compiler for an example of this.
1144
1145 .. versionadded:: 2.0.0rc1
1146
1147 """
1148
1149 _bind_translate_re: ClassVar[Pattern[str]]
1150 _bind_translate_chars: ClassVar[Mapping[str, str]]
1151
1152 is_sql = True
1153
1154 compound_keywords = COMPOUND_KEYWORDS
1155
1156 isdelete: bool = False
1157 isinsert: bool = False
1158 isupdate: bool = False
1159 """class-level defaults which can be set at the instance
1160 level to define if this Compiled instance represents
1161 INSERT/UPDATE/DELETE
1162 """
1163
1164 postfetch: Optional[List[Column[Any]]]
1165 """list of columns that can be post-fetched after INSERT or UPDATE to
1166 receive server-updated values"""
1167
1168 insert_prefetch: Sequence[Column[Any]] = ()
1169 """list of columns for which default values should be evaluated before
1170 an INSERT takes place"""
1171
1172 update_prefetch: Sequence[Column[Any]] = ()
1173 """list of columns for which onupdate default values should be evaluated
1174 before an UPDATE takes place"""
1175
1176 implicit_returning: Optional[Sequence[ColumnElement[Any]]] = None
1177 """list of "implicit" returning columns for a toplevel INSERT or UPDATE
1178 statement, used to receive newly generated values of columns.
1179
1180 .. versionadded:: 2.0 ``implicit_returning`` replaces the previous
1181 ``returning`` collection, which was not a generalized RETURNING
1182 collection and instead was in fact specific to the "implicit returning"
1183 feature.
1184
1185 """
1186
1187 isplaintext: bool = False
1188
1189 binds: Dict[str, BindParameter[Any]]
1190 """a dictionary of bind parameter keys to BindParameter instances."""
1191
1192 bind_names: Dict[BindParameter[Any], str]
1193 """a dictionary of BindParameter instances to "compiled" names
1194 that are actually present in the generated SQL"""
1195
1196 stack: List[_CompilerStackEntry]
1197 """major statements such as SELECT, INSERT, UPDATE, DELETE are
1198 tracked in this stack using an entry format."""
1199
1200 returning_precedes_values: bool = False
1201 """set to True classwide to generate RETURNING
1202 clauses before the VALUES or WHERE clause (i.e. MSSQL)
1203 """
1204
1205 render_table_with_column_in_update_from: bool = False
1206 """set to True classwide to indicate the SET clause
1207 in a multi-table UPDATE statement should qualify
1208 columns with the table name (i.e. MySQL only)
1209 """
1210
1211 ansi_bind_rules: bool = False
1212 """SQL 92 doesn't allow bind parameters to be used
1213 in the columns clause of a SELECT, nor does it allow
1214 ambiguous expressions like "? = ?". A compiler
    subclass can set this flag to True if the target
1216 driver/DB enforces this
1217 """
1218
1219 bindtemplate: str
1220 """template to render bound parameters based on paramstyle."""
1221
1222 compilation_bindtemplate: str
1223 """template used by compiler to render parameters before positional
1224 paramstyle application"""
1225
1226 _numeric_binds_identifier_char: str
    """Character that's used as the identifier of a numerical bind param.
1228 For example if this char is set to ``$``, numerical binds will be rendered
1229 in the form ``$1, $2, $3``.
1230 """
1231
1232 _result_columns: List[ResultColumnsEntry]
1233 """relates label names in the final SQL to a tuple of local
1234 column/label name, ColumnElement object (if any) and
1235 TypeEngine. CursorResult uses this for type processing and
1236 column targeting"""
1237
1238 _textual_ordered_columns: bool = False
1239 """tell the result object that the column names as rendered are important,
1240 but they are also "ordered" vs. what is in the compiled object here.
1241
1242 As of 1.4.42 this condition is only present when the statement is a
1243 TextualSelect, e.g. text("....").columns(...), where it is required
1244 that the columns are considered positionally and not by name.
1245
1246 """
1247
1248 _ad_hoc_textual: bool = False
1249 """tell the result that we encountered text() or '*' constructs in the
1250 middle of the result columns, but we also have compiled columns, so
1251 if the number of columns in cursor.description does not match how many
1252 expressions we have, that means we can't rely on positional at all and
1253 should match on name.
1254
1255 """
1256
1257 _ordered_columns: bool = True
1258 """
1259 if False, means we can't be sure the list of entries
1260 in _result_columns is actually the rendered order. Usually
1261 True unless using an unordered TextualSelect.
1262 """
1263
1264 _loose_column_name_matching: bool = False
1265 """tell the result object that the SQL statement is textual, wants to match
1266 up to Column objects, and may be using the ._tq_label in the SELECT rather
1267 than the base name.
1268
1269 """
1270
1271 _numeric_binds: bool = False
1272 """
1273 True if paramstyle is "numeric". This paramstyle is trickier than
1274 all the others.
1275
1276 """
1277
1278 _render_postcompile: bool = False
1279 """
1280 whether to render out POSTCOMPILE params during the compile phase.
1281
1282 This attribute is used only for end-user invocation of stmt.compile();
1283 it's never used for actual statement execution, where instead the
1284 dialect internals access and render the internal postcompile structure
1285 directly.
1286
1287 """
1288
1289 _post_compile_expanded_state: Optional[ExpandedState] = None
1290 """When render_postcompile is used, the ``ExpandedState`` used to create
1291 the "expanded" SQL is assigned here, and then used by the ``.params``
1292 accessor and ``.construct_params()`` methods for their return values.
1293
1294 .. versionadded:: 2.0.0rc1
1295
1296 """
1297
1298 _pre_expanded_string: Optional[str] = None
1299 """Stores the original string SQL before 'post_compile' is applied,
1300 for cases where 'post_compile' were used.
1301
1302 """
1303
1304 _pre_expanded_positiontup: Optional[List[str]] = None
1305
1306 _insertmanyvalues: Optional[_InsertManyValues] = None
1307
1308 _insert_crud_params: Optional[crud._CrudParamSequence] = None
1309
1310 literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset()
1311 """bindparameter objects that are rendered as literal values at statement
1312 execution time.
1313
1314 """
1315
1316 post_compile_params: FrozenSet[BindParameter[Any]] = frozenset()
1317 """bindparameter objects that are rendered as bound parameter placeholders
1318 at statement execution time.
1319
1320 """
1321
1322 escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT
1323 """Late escaping of bound parameter names that has to be converted
1324 to the original name when looking in the parameter dictionary.
1325
1326 """
1327
1328 has_out_parameters = False
1329 """if True, there are bindparam() objects that have the isoutparam
1330 flag set."""
1331
1332 postfetch_lastrowid = False
    """if True, and this is an insert, use cursor.lastrowid to populate
1334 result.inserted_primary_key. """
1335
1336 _cache_key_bind_match: Optional[
1337 Tuple[
1338 Dict[
1339 BindParameter[Any],
1340 List[BindParameter[Any]],
1341 ],
1342 Dict[
1343 str,
1344 BindParameter[Any],
1345 ],
1346 ]
1347 ] = None
1348 """a mapping that will relate the BindParameter object we compile
1349 to those that are part of the extracted collection of parameters
1350 in the cache key, if we were given a cache key.
1351
1352 """
1353
1354 positiontup: Optional[List[str]] = None
1355 """for a compiled construct that uses a positional paramstyle, will be
1356 a sequence of strings, indicating the names of bound parameters in order.
1357
1358 This is used in order to render bound parameters in their correct order,
1359 and is combined with the :attr:`_sql.Compiled.params` dictionary to
1360 render parameters.
1361
1362 This sequence always contains the unescaped name of the parameters.
1363
1364 .. seealso::
1365
1366 :ref:`faq_sql_expression_string` - includes a usage example for
1367 debugging use cases.
1368
1369 """
1370 _values_bindparam: Optional[List[str]] = None
1371
1372 _visited_bindparam: Optional[List[str]] = None
1373
1374 inline: bool = False
1375
1376 ctes: Optional[MutableMapping[CTE, str]]
1377
1378 # Detect same CTE references - Dict[(level, name), cte]
1379 # Level is required for supporting nesting
1380 ctes_by_level_name: Dict[Tuple[int, str], CTE]
1381
1382 # To retrieve key/level in ctes_by_level_name -
1383 # Dict[cte_reference, (level, cte_name, cte_opts)]
1384 level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]]
1385
1386 ctes_recursive: bool
1387
1388 _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]")
1389 _pyformat_pattern = re.compile(r"%\(([^)]+?)\)s")
1390 _positional_pattern = re.compile(
1391 f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}"
1392 )
1393
    @classmethod
    def _init_compiler_cls(cls):
        """Class-level setup hook for this compiler class.

        Delegates to :meth:`._init_bind_translate`, which assigns the
        ``_bind_translate_re`` and ``_bind_translate_chars`` class
        attributes.

        """
        # NOTE(review): the caller of this hook is not visible here;
        # presumably it runs once per compiler class - confirm.
        cls._init_bind_translate()
1397
1398 @classmethod
1399 def _init_bind_translate(cls):
1400 reg = re.escape("".join(cls.bindname_escape_characters))
1401 cls._bind_translate_re = re.compile(f"[{reg}]")
1402 cls._bind_translate_chars = cls.bindname_escape_characters
1403
    def __init__(
        self,
        dialect: Dialect,
        statement: Optional[ClauseElement],
        cache_key: Optional[CacheKey] = None,
        column_keys: Optional[Sequence[str]] = None,
        for_executemany: bool = False,
        linting: Linting = NO_LINTING,
        _supporting_against: Optional[SQLCompiler] = None,
        **kwargs: Any,
    ):
        """Construct a new :class:`.SQLCompiler` object.

        :param dialect: :class:`.Dialect` to be used

        :param statement: :class:`_expression.ClauseElement` to be compiled

        :param cache_key: optional cache key for the statement.  When
         given, the bound parameters in ``cache_key[1]`` are indexed into
         the ``_cache_key_bind_match`` collection, which
         :meth:`.SQLCompiler.construct_params` consults when
         ``extracted_parameters`` are passed.

        :param column_keys: a list of column names to be compiled into an
         INSERT or UPDATE statement.

        :param for_executemany: whether INSERT / UPDATE statements should
         expect that they are to be invoked in an "executemany" style,
         which may impact how the statement will be expected to return the
         values of defaults and autoincrement / sequences and similar.
         Depending on the backend and driver in use, support for retrieving
         these values may be disabled which means SQL expressions may
         be rendered inline, RETURNING may not be rendered, etc.

        :param linting: linting setting; stored as ``self.linting``.

        :param _supporting_against: optional :class:`.SQLCompiler`; when
         given, the contents of its instance ``__dict__`` are copied onto
         this compiler after compilation, other than the dialect-specific
         entries ``state``, ``dialect``, ``preparer``, ``positional``,
         ``_numeric_binds``, ``compilation_bindtemplate`` and
         ``bindtemplate``.

        :param kwargs: additional keyword arguments to be consumed by the
         superclass.

        """
        self.column_keys = column_keys

        self.cache_key = cache_key

        if cache_key:
            # cksm: bind parameter key -> BindParameter;
            # ckbm: BindParameter -> list containing that BindParameter
            cksm = {b.key: b for b in cache_key[1]}
            ckbm = {b: [b] for b in cache_key[1]}
            self._cache_key_bind_match = (ckbm, cksm)

        # compile INSERT/UPDATE defaults/sequences to expect executemany
        # style execution, which may mean no pre-execute of defaults,
        # or no RETURNING
        self.for_executemany = for_executemany

        self.linting = linting

        # a dictionary of bind parameter keys to BindParameter
        # instances.
        self.binds = {}

        # a dictionary of BindParameter instances to "compiled" names
        # that are actually present in the generated SQL
        self.bind_names = util.column_dict()

        # stack which keeps track of nested SELECT statements
        self.stack = []

        self._result_columns = []

        # true if the paramstyle is positional
        self.positional = dialect.positional
        if self.positional:
            # any paramstyle beginning with "numeric" uses numbered
            # placeholders; "numeric_dollar" uses "$" as the prefix
            # character, anything else ":"
            self._numeric_binds = nb = dialect.paramstyle.startswith("numeric")
            if nb:
                self._numeric_binds_identifier_char = (
                    "$" if dialect.paramstyle == "numeric_dollar" else ":"
                )

            # positional paramstyles are compiled using the pyformat
            # template first; _process_numeric() / _process_positional()
            # rewrite the placeholders further below
            self.compilation_bindtemplate = _pyformat_template
        else:
            self.compilation_bindtemplate = BIND_TEMPLATES[dialect.paramstyle]

        self.ctes = None

        self.label_length = (
            dialect.label_length or dialect.max_identifier_length
        )

        # a map which tracks "anonymous" identifiers that are created on
        # the fly here
        self.anon_map = prefix_anon_map()

        # a map which tracks "truncated" names based on
        # dialect.label_length or dialect.max_identifier_length
        self.truncated_names: Dict[Tuple[str, str], str] = {}
        self._truncated_counters: Dict[str, int] = {}

        # NOTE(review): presumably this superclass call is what runs the
        # compilation and establishes self.state / self.string, both of
        # which are consulted below - confirm against Compiled.__init__.
        Compiled.__init__(self, dialect, statement, **kwargs)

        if self.isinsert or self.isupdate or self.isdelete:
            if TYPE_CHECKING:
                assert isinstance(statement, UpdateBase)

            if self.isinsert or self.isupdate:
                if TYPE_CHECKING:
                    assert isinstance(statement, ValuesBase)
                # "inline" is set when the statement requests it, or for
                # executemany when the statement is not an INSERT, or is
                # an INSERT with return_defaults on a dialect that
                # reports insert_executemany_returning
                if statement._inline:
                    self.inline = True
                elif self.for_executemany and (
                    not self.isinsert
                    or (
                        self.dialect.insert_executemany_returning
                        and statement._return_defaults
                    )
                ):
                    self.inline = True

        self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]

        if _supporting_against:
            # take on the state of the other compiler, leaving in place
            # the attributes in the exclusion set below
            self.__dict__.update(
                {
                    k: v
                    for k, v in _supporting_against.__dict__.items()
                    if k
                    not in {
                        "state",
                        "dialect",
                        "preparer",
                        "positional",
                        "_numeric_binds",
                        "compilation_bindtemplate",
                        "bindtemplate",
                    }
                }
            )

        if self.state is CompilerState.STRING_APPLIED:
            # rewrite the pyformat placeholders used during compilation
            # into the positional style of the dialect
            if self.positional:
                if self._numeric_binds:
                    self._process_numeric()
                else:
                    self._process_positional()

            if self._render_postcompile:
                parameters = self.construct_params(
                    escape_names=False,
                    _no_postcompile=True,
                )

                self._process_parameters_for_postcompile(
                    parameters, _populate_self=True
                )
1549
1550 @property
1551 def insert_single_values_expr(self) -> Optional[str]:
1552 """When an INSERT is compiled with a single set of parameters inside
1553 a VALUES expression, the string is assigned here, where it can be
1554 used for insert batching schemes to rewrite the VALUES expression.
1555
1556 .. versionchanged:: 2.0 This collection is no longer used by
1557 SQLAlchemy's built-in dialects, in favor of the currently
1558 internal ``_insertmanyvalues`` collection that is used only by
1559 :class:`.SQLCompiler`.
1560
1561 """
1562 if self._insertmanyvalues is None:
1563 return None
1564 else:
1565 return self._insertmanyvalues.single_values_expr
1566
1567 @util.ro_memoized_property
1568 def effective_returning(self) -> Optional[Sequence[ColumnElement[Any]]]:
1569 """The effective "returning" columns for INSERT, UPDATE or DELETE.
1570
1571 This is either the so-called "implicit returning" columns which are
1572 calculated by the compiler on the fly, or those present based on what's
1573 present in ``self.statement._returning`` (expanded into individual
1574 columns using the ``._all_selected_columns`` attribute) i.e. those set
1575 explicitly using the :meth:`.UpdateBase.returning` method.
1576
1577 .. versionadded:: 2.0
1578
1579 """
1580 if self.implicit_returning:
1581 return self.implicit_returning
1582 elif self.statement is not None and is_dml(self.statement):
1583 return [
1584 c
1585 for c in self.statement._all_selected_columns
1586 if is_column_element(c)
1587 ]
1588
1589 else:
1590 return None
1591
    @property
    def returning(self):
        """backwards compatibility; returns the
        effective_returning collection.

        This accessor adds no logic of its own; the value is whatever
        :attr:`.SQLCompiler.effective_returning` produces, including
        ``None``.

        """
        return self.effective_returning
1599
1600 @property
1601 def current_executable(self):
1602 """Return the current 'executable' that is being compiled.
1603
1604 This is currently the :class:`_sql.Select`, :class:`_sql.Insert`,
1605 :class:`_sql.Update`, :class:`_sql.Delete`,
1606 :class:`_sql.CompoundSelect` object that is being compiled.
1607 Specifically it's assigned to the ``self.stack`` list of elements.
1608
1609 When a statement like the above is being compiled, it normally
1610 is also assigned to the ``.statement`` attribute of the
1611 :class:`_sql.Compiler` object. However, all SQL constructs are
1612 ultimately nestable, and this attribute should never be consulted
1613 by a ``visit_`` method, as it is not guaranteed to be assigned
1614 nor guaranteed to correspond to the current statement being compiled.
1615
1616 """
1617 try:
1618 return self.stack[-1]["selectable"]
1619 except IndexError as ie:
1620 raise IndexError("Compiler does not have a stack entry") from ie
1621
1622 @property
1623 def prefetch(self):
1624 return list(self.insert_prefetch) + list(self.update_prefetch)
1625
    @util.memoized_property
    def _global_attributes(self) -> Dict[Any, Any]:
        """An initially empty dictionary associated with this compiler."""
        # NOTE(review): per the memoized_property decorator, presumably the
        # same dict is returned on every access after the first - confirm.
        return {}
1629
1630 @util.memoized_instancemethod
1631 def _init_cte_state(self) -> MutableMapping[CTE, str]:
1632 """Initialize collections related to CTEs only if
1633 a CTE is located, to save on the overhead of
1634 these collections otherwise.
1635
1636 """
1637 # collect CTEs to tack on top of a SELECT
1638 # To store the query to print - Dict[cte, text_query]
1639 ctes: MutableMapping[CTE, str] = util.OrderedDict()
1640 self.ctes = ctes
1641
1642 # Detect same CTE references - Dict[(level, name), cte]
1643 # Level is required for supporting nesting
1644 self.ctes_by_level_name = {}
1645
1646 # To retrieve key/level in ctes_by_level_name -
1647 # Dict[cte_reference, (level, cte_name, cte_opts)]
1648 self.level_name_by_cte = {}
1649
1650 self.ctes_recursive = False
1651
1652 return ctes
1653
1654 @contextlib.contextmanager
1655 def _nested_result(self):
1656 """special API to support the use case of 'nested result sets'"""
1657 result_columns, ordered_columns = (
1658 self._result_columns,
1659 self._ordered_columns,
1660 )
1661 self._result_columns, self._ordered_columns = [], False
1662
1663 try:
1664 if self.stack:
1665 entry = self.stack[-1]
1666 entry["need_result_map_for_nested"] = True
1667 else:
1668 entry = None
1669 yield self._result_columns, self._ordered_columns
1670 finally:
1671 if entry:
1672 entry.pop("need_result_map_for_nested")
1673 self._result_columns, self._ordered_columns = (
1674 result_columns,
1675 ordered_columns,
1676 )
1677
1678 def _process_positional(self):
1679 assert not self.positiontup
1680 assert self.state is CompilerState.STRING_APPLIED
1681 assert not self._numeric_binds
1682
1683 if self.dialect.paramstyle == "format":
1684 placeholder = "%s"
1685 else:
1686 assert self.dialect.paramstyle == "qmark"
1687 placeholder = "?"
1688
1689 positions = []
1690
1691 def find_position(m: re.Match[str]) -> str:
1692 normal_bind = m.group(1)
1693 if normal_bind:
1694 positions.append(normal_bind)
1695 return placeholder
1696 else:
1697 # this a post-compile bind
1698 positions.append(m.group(2))
1699 return m.group(0)
1700
1701 self.string = re.sub(
1702 self._positional_pattern, find_position, self.string
1703 )
1704
1705 if self.escaped_bind_names:
1706 reverse_escape = {v: k for k, v in self.escaped_bind_names.items()}
1707 assert len(self.escaped_bind_names) == len(reverse_escape)
1708 self.positiontup = [
1709 reverse_escape.get(name, name) for name in positions
1710 ]
1711 else:
1712 self.positiontup = positions
1713
1714 if self._insertmanyvalues:
1715 positions = []
1716
1717 single_values_expr = re.sub(
1718 self._positional_pattern,
1719 find_position,
1720 self._insertmanyvalues.single_values_expr,
1721 )
1722 insert_crud_params = [
1723 (
1724 v[0],
1725 v[1],
1726 re.sub(self._positional_pattern, find_position, v[2]),
1727 v[3],
1728 )
1729 for v in self._insertmanyvalues.insert_crud_params
1730 ]
1731
1732 self._insertmanyvalues = self._insertmanyvalues._replace(
1733 single_values_expr=single_values_expr,
1734 insert_crud_params=insert_crud_params,
1735 )
1736
1737 def _process_numeric(self):
1738 assert self._numeric_binds
1739 assert self.state is CompilerState.STRING_APPLIED
1740
1741 num = 1
1742 param_pos: Dict[str, str] = {}
1743 order: Iterable[str]
1744 if self._insertmanyvalues and self._values_bindparam is not None:
1745 # bindparams that are not in values are always placed first.
1746 # this avoids the need of changing them when using executemany
1747 # values () ()
1748 order = itertools.chain(
1749 (
1750 name
1751 for name in self.bind_names.values()
1752 if name not in self._values_bindparam
1753 ),
1754 self.bind_names.values(),
1755 )
1756 else:
1757 order = self.bind_names.values()
1758
1759 for bind_name in order:
1760 if bind_name in param_pos:
1761 continue
1762 bind = self.binds[bind_name]
1763 if (
1764 bind in self.post_compile_params
1765 or bind in self.literal_execute_params
1766 ):
1767 # set to None to just mark the in positiontup, it will not
1768 # be replaced below.
1769 param_pos[bind_name] = None # type: ignore
1770 else:
1771 ph = f"{self._numeric_binds_identifier_char}{num}"
1772 num += 1
1773 param_pos[bind_name] = ph
1774
1775 self.next_numeric_pos = num
1776
1777 self.positiontup = list(param_pos)
1778 if self.escaped_bind_names:
1779 len_before = len(param_pos)
1780 param_pos = {
1781 self.escaped_bind_names.get(name, name): pos
1782 for name, pos in param_pos.items()
1783 }
1784 assert len(param_pos) == len_before
1785
1786 # Can't use format here since % chars are not escaped.
1787 self.string = self._pyformat_pattern.sub(
1788 lambda m: param_pos[m.group(1)], self.string
1789 )
1790
1791 if self._insertmanyvalues:
1792 single_values_expr = (
1793 # format is ok here since single_values_expr includes only
1794 # place-holders
1795 self._insertmanyvalues.single_values_expr
1796 % param_pos
1797 )
1798 insert_crud_params = [
1799 (v[0], v[1], "%s", v[3])
1800 for v in self._insertmanyvalues.insert_crud_params
1801 ]
1802
1803 self._insertmanyvalues = self._insertmanyvalues._replace(
1804 # This has the numbers (:1, :2)
1805 single_values_expr=single_values_expr,
1806 # The single binds are instead %s so they can be formatted
1807 insert_crud_params=insert_crud_params,
1808 )
1809
1810 @util.memoized_property
1811 def _bind_processors(
1812 self,
1813 ) -> MutableMapping[
1814 str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]]
1815 ]:
1816 # mypy is not able to see the two value types as the above Union,
1817 # it just sees "object". don't know how to resolve
1818 return {
1819 key: value # type: ignore
1820 for key, value in (
1821 (
1822 self.bind_names[bindparam],
1823 (
1824 bindparam.type._cached_bind_processor(self.dialect)
1825 if not bindparam.type._is_tuple_type
1826 else tuple(
1827 elem_type._cached_bind_processor(self.dialect)
1828 for elem_type in cast(
1829 TupleType, bindparam.type
1830 ).types
1831 )
1832 ),
1833 )
1834 for bindparam in self.bind_names
1835 )
1836 if value is not None
1837 }
1838
1839 def is_subquery(self):
1840 return len(self.stack) > 1
1841
    @property
    def sql_compiler(self) -> Self:
        """Return this compiler itself."""
        return self
1845
1846 def construct_expanded_state(
1847 self,
1848 params: Optional[_CoreSingleExecuteParams] = None,
1849 escape_names: bool = True,
1850 ) -> ExpandedState:
1851 """Return a new :class:`.ExpandedState` for a given parameter set.
1852
1853 For queries that use "expanding" or other late-rendered parameters,
1854 this method will provide for both the finalized SQL string as well
1855 as the parameters that would be used for a particular parameter set.
1856
1857 .. versionadded:: 2.0.0rc1
1858
1859 """
1860 parameters = self.construct_params(
1861 params,
1862 escape_names=escape_names,
1863 _no_postcompile=True,
1864 )
1865 return self._process_parameters_for_postcompile(
1866 parameters,
1867 )
1868
1869 def construct_params(
1870 self,
1871 params: Optional[_CoreSingleExecuteParams] = None,
1872 extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
1873 escape_names: bool = True,
1874 _group_number: Optional[int] = None,
1875 _check: bool = True,
1876 _no_postcompile: bool = False,
1877 ) -> _MutableCoreSingleExecuteParams:
1878 """return a dictionary of bind parameter keys and values"""
1879
1880 if self._render_postcompile and not _no_postcompile:
1881 assert self._post_compile_expanded_state is not None
1882 if not params:
1883 return dict(self._post_compile_expanded_state.parameters)
1884 else:
1885 raise exc.InvalidRequestError(
1886 "can't construct new parameters when render_postcompile "
1887 "is used; the statement is hard-linked to the original "
1888 "parameters. Use construct_expanded_state to generate a "
1889 "new statement and parameters."
1890 )
1891
1892 has_escaped_names = escape_names and bool(self.escaped_bind_names)
1893
1894 if extracted_parameters:
1895 # related the bound parameters collected in the original cache key
1896 # to those collected in the incoming cache key. They will not have
1897 # matching names but they will line up positionally in the same
1898 # way. The parameters present in self.bind_names may be clones of
1899 # these original cache key params in the case of DML but the .key
1900 # will be guaranteed to match.
1901 if self.cache_key is None:
1902 raise exc.CompileError(
1903 "This compiled object has no original cache key; "
1904 "can't pass extracted_parameters to construct_params"
1905 )
1906 else:
1907 orig_extracted = self.cache_key[1]
1908
1909 ckbm_tuple = self._cache_key_bind_match
1910 assert ckbm_tuple is not None
1911 ckbm, _ = ckbm_tuple
1912 resolved_extracted = {
1913 bind: extracted
1914 for b, extracted in zip(orig_extracted, extracted_parameters)
1915 for bind in ckbm[b]
1916 }
1917 else:
1918 resolved_extracted = None
1919
1920 if params:
1921 pd = {}
1922 for bindparam, name in self.bind_names.items():
1923 escaped_name = (
1924 self.escaped_bind_names.get(name, name)
1925 if has_escaped_names
1926 else name
1927 )
1928
1929 if bindparam.key in params:
1930 pd[escaped_name] = params[bindparam.key]
1931 elif name in params:
1932 pd[escaped_name] = params[name]
1933
1934 elif _check and bindparam.required:
1935 if _group_number:
1936 raise exc.InvalidRequestError(
1937 "A value is required for bind parameter %r, "
1938 "in parameter group %d"
1939 % (bindparam.key, _group_number),
1940 code="cd3x",
1941 )
1942 else:
1943 raise exc.InvalidRequestError(
1944 "A value is required for bind parameter %r"
1945 % bindparam.key,
1946 code="cd3x",
1947 )
1948 else:
1949 if resolved_extracted:
1950 value_param = resolved_extracted.get(
1951 bindparam, bindparam
1952 )
1953 else:
1954 value_param = bindparam
1955
1956 if bindparam.callable:
1957 pd[escaped_name] = value_param.effective_value
1958 else:
1959 pd[escaped_name] = value_param.value
1960 return pd
1961 else:
1962 pd = {}
1963 for bindparam, name in self.bind_names.items():
1964 escaped_name = (
1965 self.escaped_bind_names.get(name, name)
1966 if has_escaped_names
1967 else name
1968 )
1969
1970 if _check and bindparam.required:
1971 if _group_number:
1972 raise exc.InvalidRequestError(
1973 "A value is required for bind parameter %r, "
1974 "in parameter group %d"
1975 % (bindparam.key, _group_number),
1976 code="cd3x",
1977 )
1978 else:
1979 raise exc.InvalidRequestError(
1980 "A value is required for bind parameter %r"
1981 % bindparam.key,
1982 code="cd3x",
1983 )
1984
1985 if resolved_extracted:
1986 value_param = resolved_extracted.get(bindparam, bindparam)
1987 else:
1988 value_param = bindparam
1989
1990 if bindparam.callable:
1991 pd[escaped_name] = value_param.effective_value
1992 else:
1993 pd[escaped_name] = value_param.value
1994
1995 return pd
1996
1997 @util.memoized_instancemethod
1998 def _get_set_input_sizes_lookup(self):
1999 dialect = self.dialect
2000
2001 include_types = dialect.include_set_input_sizes
2002 exclude_types = dialect.exclude_set_input_sizes
2003
2004 dbapi = dialect.dbapi
2005
2006 def lookup_type(typ):
2007 dbtype = typ._unwrapped_dialect_impl(dialect).get_dbapi_type(dbapi)
2008
2009 if (
2010 dbtype is not None
2011 and (exclude_types is None or dbtype not in exclude_types)
2012 and (include_types is None or dbtype in include_types)
2013 ):
2014 return dbtype
2015 else:
2016 return None
2017
2018 inputsizes = {}
2019
2020 literal_execute_params = self.literal_execute_params
2021
2022 for bindparam in self.bind_names:
2023 if bindparam in literal_execute_params:
2024 continue
2025
2026 if bindparam.type._is_tuple_type:
2027 inputsizes[bindparam] = [
2028 lookup_type(typ)
2029 for typ in cast(TupleType, bindparam.type).types
2030 ]
2031 else:
2032 inputsizes[bindparam] = lookup_type(bindparam.type)
2033
2034 return inputsizes
2035
    @property
    def params(self):
        """Return the bind param dictionary embedded into this
        compiled object, for those values that are present.

        The dictionary is produced by calling
        :meth:`.SQLCompiler.construct_params` with ``_check=False``, so
        that required parameters which have no value do not raise.

        .. seealso::

            :ref:`faq_sql_expression_string` - includes a usage example for
            debugging use cases.

        """
        return self.construct_params(_check=False)
2048
def _process_parameters_for_postcompile(
    self,
    parameters: _MutableCoreSingleExecuteParams,
    _populate_self: bool = False,
) -> ExpandedState:
    """handle special post compile parameters.

    These include:

    * "expanding" parameters -typically IN tuples that are rendered
      on a per-parameter basis for an otherwise fixed SQL statement string.

    * literal_binds compiled with the literal_execute flag.  Used for
      things like SQL Server "TOP N" where the driver does not accommodate
      N as a bound parameter.

    :param parameters: mutable dictionary of parameter values.  It is
     modified in place: values consumed for literal rendering or for
     expansion are popped from it, and the per-element entries of
     expanding parameters that are not ``literal_execute`` are added.
    :param _populate_self: when True, additionally store the expanded
     statement and positiontup on this compiler, retaining the
     pre-expansion string and positiontup so that a later call starts
     again from the unexpanded form.
    :return: an :class:`.ExpandedState` composed of the rewritten
     statement, the same ``parameters`` dictionary, bind processors for
     the newly generated parameter names, the new positiontup (``None``
     when the compiler is not positional) and a mapping of each expanded
     parameter name to the list of keys generated for it.

    """

    expanded_parameters = {}
    new_positiontup: Optional[List[str]]

    # start from the un-expanded statement; a previous call made with
    # _populate_self=True will have replaced self.string and stashed the
    # original in self._pre_expanded_string
    pre_expanded_string = self._pre_expanded_string
    if pre_expanded_string is None:
        pre_expanded_string = self.string

    if self.positional:
        new_positiontup = []

        # same fallback as for the string, applied to the positiontup
        pre_expanded_positiontup = self._pre_expanded_positiontup
        if pre_expanded_positiontup is None:
            pre_expanded_positiontup = self.positiontup

    else:
        new_positiontup = pre_expanded_positiontup = None

    # the one processors mapping is viewed under two typings: a single
    # processor per name, or a sequence of processors per name that is
    # indexed per tuple element further below
    processors = self._bind_processors
    single_processors = cast(
        "Mapping[str, _BindProcessorType[Any]]", processors
    )
    tuple_processors = cast(
        "Mapping[str, Sequence[_BindProcessorType[Any]]]", processors
    )

    new_processors: Dict[str, _BindProcessorType[Any]] = {}

    # escaped name -> SQL text to substitute into the statement
    replacement_expressions: Dict[str, Any] = {}
    # escaped name -> expanded (key, value) pairs; retained so that a
    # name encountered more than once is only expanded the first time
    to_update_sets: Dict[str, Any] = {}

    # notes:
    # *unescaped* parameter names in:
    # self.bind_names, self.binds, self._bind_processors, self.positiontup
    #
    # *escaped* parameter names in:
    # construct_params(), replacement_expressions

    # only populated for positional compilers that use numeric binds
    numeric_positiontup: Optional[List[str]] = None

    if self.positional and pre_expanded_positiontup is not None:
        names: Iterable[str] = pre_expanded_positiontup
        if self._numeric_binds:
            numeric_positiontup = []
    else:
        names = self.bind_names.values()

    ebn = self.escaped_bind_names
    for name in names:
        escaped_name = ebn.get(name, name) if ebn else name
        parameter = self.binds[name]

        if parameter in self.literal_execute_params:
            # rendered inline as a literal; the value is popped from
            # ``parameters`` on the first encounter of the name only
            if escaped_name not in replacement_expressions:
                replacement_expressions[escaped_name] = (
                    self.render_literal_bindparam(
                        parameter,
                        render_literal_value=parameters.pop(escaped_name),
                    )
                )
            continue

        if parameter in self.post_compile_params:
            if escaped_name in replacement_expressions:
                # already expanded on an earlier occurrence; reuse the
                # generated pairs.  the original values are gone at this
                # point, hence None
                to_update = to_update_sets[escaped_name]
                values = None
            else:
                # we are removing the parameter from parameters
                # because it is a list value, which is not expected by
                # TypeEngine objects that would otherwise be asked to
                # process it. the single name is being replaced with
                # individual numbered parameters for each value in the
                # param.
                #
                # note we are also inserting *escaped* parameter names
                # into the given dictionary.   default dialect will
                # use these param names directly as they will not be
                # in the escaped_bind_names dictionary.
                values = parameters.pop(name)

                leep_res = self._literal_execute_expanding_parameter(
                    escaped_name, parameter, values
                )
                (to_update, replacement_expr) = leep_res

                to_update_sets[escaped_name] = to_update
                replacement_expressions[escaped_name] = replacement_expr

            if not parameter.literal_execute:
                parameters.update(to_update)
                if parameter.type._is_tuple_type:
                    assert values is not None
                    # one processor entry per (row, element) position,
                    # keyed "<name>_<i>_<j>" with both indexes 1-based
                    new_processors.update(
                        (
                            "%s_%s_%s" % (name, i, j),
                            tuple_processors[name][j - 1],
                        )
                        for i, tuple_element in enumerate(values, 1)
                        for j, _ in enumerate(tuple_element, 1)
                        if name in tuple_processors
                        and tuple_processors[name][j - 1] is not None
                    )
                else:
                    # every generated key shares the processor of the
                    # original parameter, if it has one
                    new_processors.update(
                        (key, single_processors[name])
                        for key, _ in to_update
                        if name in single_processors
                    )
                if numeric_positiontup is not None:
                    numeric_positiontup.extend(
                        name for name, _ in to_update
                    )
                elif new_positiontup is not None:
                    # to_update has escaped names, but that's ok since
                    # these are new names, that aren't in the
                    # escaped_bind_names dict.
                    new_positiontup.extend(name for name, _ in to_update)
                expanded_parameters[name] = [
                    expand_key for expand_key, _ in to_update
                ]
        elif new_positiontup is not None:
            # ordinary parameter: keeps its place in the positiontup
            new_positiontup.append(name)

    def process_expanding(m):
        # substitution callback for self._post_compile_pattern; group 1
        # is the key into replacement_expressions
        key = m.group(1)
        expr = replacement_expressions[key]

        # if POSTCOMPILE included a bind_expression, render that
        # around each element
        if m.group(2):
            # group 2 is split on "~~"; pieces 1 and 3 are used as the
            # text placed left and right of each element
            tok = m.group(2).split("~~")
            be_left, be_right = tok[1], tok[3]
            expr = ", ".join(
                "%s%s%s" % (be_left, exp, be_right)
                for exp in expr.split(", ")
            )
        return expr

    statement = re.sub(
        self._post_compile_pattern, process_expanding, pre_expanded_string
    )

    if numeric_positiontup is not None:
        assert new_positiontup is not None
        # number the expanded parameters starting at next_numeric_pos
        param_pos = {
            key: f"{self._numeric_binds_identifier_char}{num}"
            for num, key in enumerate(
                numeric_positiontup, self.next_numeric_pos
            )
        }
        # Can't use format here since % chars are not escaped.
        statement = self._pyformat_pattern.sub(
            lambda m: param_pos[m.group(1)], statement
        )
        # the expanded names go after the non-expanded ones
        new_positiontup.extend(numeric_positiontup)

    expanded_state = ExpandedState(
        statement,
        parameters,
        new_processors,
        new_positiontup,
        expanded_parameters,
    )

    if _populate_self:
        # this is for the "render_postcompile" flag, which is not
        # otherwise used internally and is for end-user debugging and
        # special use cases.
        self._pre_expanded_string = pre_expanded_string
        self._pre_expanded_positiontup = pre_expanded_positiontup
        self.string = expanded_state.statement
        self.positiontup = (
            list(expanded_state.positiontup or ())
            if self.positional
            else None
        )
        self._post_compile_expanded_state = expanded_state

    return expanded_state
2245
@util.preload_module("sqlalchemy.engine.cursor")
def _create_result_map(self):
    """Return a description match map built from ``_result_columns``.

    Utility method used for unit tests only.

    """
    metadata_cls = util.preloaded.engine_cursor.CursorResultMetaData
    return metadata_cls._create_description_match_map(self._result_columns)
2253
2254 # assigned by crud.py for insert/update statements
2255 _get_bind_name_for_col: _BindNameForColProtocol
2256
@util.memoized_property
def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
    """Return the callable stored as ``_get_bind_name_for_col``."""
    return self._get_bind_name_for_col
2261
@util.memoized_property
@util.preload_module("sqlalchemy.engine.result")
def _inserted_primary_key_from_lastrowid_getter(self):
    """Build the ``get(lastrowid, parameters)`` callable that assembles
    the inserted primary key "row" for the compiled INSERT.

    The table is taken from ``self.compile_state.statement``.  One getter
    is prepared per primary key column; each reads that column's bind
    name out of the parameter dictionary.  The table's autoincrement
    column, when there is one, is instead served from ``lastrowid``
    (after that column type's result processor has been applied).

    """
    result = util.preloaded.engine_result

    param_key_getter = self._within_exec_param_key_getter

    assert self.compile_state is not None
    statement = self.compile_state.statement

    if TYPE_CHECKING:
        assert isinstance(statement, Insert)

    table = statement.table

    # (callable, column) pairs; each callable does
    # ``parameters.get(<bind name for col>, None)``
    getters = [
        (operator.methodcaller("get", param_key_getter(col), None), col)
        for col in table.primary_key
    ]

    autoinc_getter = None
    autoinc_col = table._autoincrement_column
    if autoinc_col is not None:
        # apply type post processors to the lastrowid
        lastrowid_processor = autoinc_col.type._cached_result_processor(
            self.dialect, None
        )
        autoinc_key = param_key_getter(autoinc_col)

        # if a bind value is present for the autoincrement column
        # in the parameters, we need to do the logic dictated by
        # #7998; honor a non-None user-passed parameter over lastrowid.
        # previously in the 1.4 series we weren't fetching lastrowid
        # at all if the key were present in the parameters
        if autoinc_key in self.binds:

            def _autoinc_getter(lastrowid, parameters):
                # falls back to lastrowid when the key is absent; an
                # explicit None in the parameters also yields lastrowid
                param_value = parameters.get(autoinc_key, lastrowid)
                if param_value is not None:
                    # they supplied non-None parameter, use that.
                    # SQLite at least is observed to return the wrong
                    # cursor.lastrowid for INSERT..ON CONFLICT so it
                    # can't be used in all cases
                    return param_value
                else:
                    # use lastrowid
                    return lastrowid

            # work around mypy https://github.com/python/mypy/issues/14027
            autoinc_getter = _autoinc_getter

    else:
        lastrowid_processor = None

    # row constructor keyed on the primary key column keys
    row_fn = result.result_tuple([col.key for col in table.primary_key])

    def get(lastrowid, parameters):
        """given cursor.lastrowid value and the parameters used for INSERT,
        return a "row" that represents the primary key, either by
        using the "lastrowid" or by extracting values from the parameters
        that were sent along with the INSERT.

        """
        if lastrowid_processor is not None:
            lastrowid = lastrowid_processor(lastrowid)

        if lastrowid is None:
            # nothing usable from the cursor; every column comes from
            # the parameters
            return row_fn(getter(parameters) for getter, col in getters)
        else:
            # the autoincrement column takes lastrowid (or the
            # user-passed value, via autoinc_getter); all other columns
            # come from the parameters
            return row_fn(
                (
                    (
                        autoinc_getter(lastrowid, parameters)
                        if autoinc_getter is not None
                        else lastrowid
                    )
                    if col is autoinc_col
                    else getter(parameters)
                )
                for getter, col in getters
            )

    return get
2345
@util.memoized_property
@util.preload_module("sqlalchemy.engine.result")
def _inserted_primary_key_from_returning_getter(self):
    """Build the ``get(row, parameters)`` callable that assembles the
    inserted primary key "row" from a RETURNING row.

    Primary key columns that are part of ``self.implicit_returning`` are
    read from the returned row by position; the remaining ones are read
    from the parameter dictionary using the column's bind name.

    """
    result = util.preloaded.engine_result

    assert self.compile_state is not None
    statement = self.compile_state.statement

    if TYPE_CHECKING:
        assert isinstance(statement, Insert)

    param_key_getter = self._within_exec_param_key_getter
    table = statement.table

    returning = self.implicit_returning
    assert returning is not None

    # column -> position within the RETURNING row
    position_of = {}
    for position, returned_col in enumerate(returning):
        position_of[returned_col] = position

    # (callable, read_from_row) pairs, one per primary key column
    pk_getters: List[Tuple[Callable[[Any], Any], bool]] = []
    for pk_col in table.primary_key:
        if pk_col in position_of:
            pk_getters.append(
                (operator.itemgetter(position_of[pk_col]), True)
            )
        else:
            pk_getters.append(
                (
                    operator.methodcaller(
                        "get", param_key_getter(pk_col), None
                    ),
                    False,
                )
            )

    make_row = result.result_tuple(
        [pk_col.key for pk_col in table.primary_key]
    )

    def get(row, parameters):
        return make_row(
            fetch(row if from_row else parameters)
            for fetch, from_row in pk_getters
        )

    return get
2390
def default_from(self) -> str:
    """Return text standing in for the FROM clause of a FROM-less SELECT.

    Invoked when a SELECT statement has no froms, and no FROM clause is
    to be appended.  This implementation contributes nothing.

    Gives Oracle Database a chance to tack on a ``FROM DUAL`` to the string
    output.

    """
    no_from_text = ""
    return no_from_text
2400
def visit_override_binds(self, override_binds, **kw):
    """SQL compile the nested element of an _OverrideBinds with
    bindparams swapped out.

    The _OverrideBinds is not normally expected to be compiled; it
    is meant to be used when an already cached statement is to be used,
    the compilation was already performed, and only the bound params should
    be swapped in at execution time.

    However, there are test cases that exercise this object, and
    additionally the ORM subquery loader is known to feed in expressions
    which include this construct into new queries (discovered in #11173),
    so it has to do the right thing at compile time as well.

    :param override_binds: the construct; its ``element`` is compiled and
     its ``translate`` mapping supplies the replacement values by key.
    :return: the SQL text of ``override_binds.element``.

    """

    # get SQL text first
    sqltext = override_binds.element._compiler_dispatch(self, **kw)

    # for a test compile that is not for caching, change binds after the
    # fact.   note that we don't try to
    # swap the bindparam as we compile, because our element may be
    # elsewhere in the statement already (e.g. a subquery or perhaps a
    # CTE) and was already visited / compiled. See
    # test_relationship_criteria.py ->
    # test_selectinload_local_criteria_subquery
    for k in override_binds.translate:
        # keys that did not produce a bind in this compilation are
        # ignored
        if k not in self.binds:
            continue
        bp = self.binds[k]

        # so this would work, just change the value of bp in place.
        # but we don't want to mutate things outside.
        # bp.value = override_binds.translate[bp.key]
        # continue

        # instead, need to replace bp with new_bp or otherwise accommodate
        # in all internal collections
        new_bp = bp._with_value(
            override_binds.translate[bp.key],
            maintain_key=True,
            required=False,
        )

        # re-point both the key entry and the name entry at the new
        # bindparam, and move the name registration over to it
        name = self.bind_names[bp]
        self.binds[k] = self.binds[name] = new_bp
        self.bind_names[new_bp] = name
        self.bind_names.pop(bp, None)

        # carry over membership in the post-compile collections; ``|=``
        # with a one-element set adds new_bp (the old bp is not removed
        # here)
        if bp in self.post_compile_params:
            self.post_compile_params |= {new_bp}
        if bp in self.literal_execute_params:
            self.literal_execute_params |= {new_bp}

        ckbm_tuple = self._cache_key_bind_match
        if ckbm_tuple:
            ckbm, cksm = ckbm_tuple
            # NOTE: the loop variable re-uses the name ``bp``; the
            # iterable ``bp._cloned_set`` is evaluated once, before the
            # first rebinding
            for bp in bp._cloned_set:
                if bp.key in cksm:
                    cb = cksm[bp.key]
                    ckbm[cb].append(new_bp)

    return sqltext
2464
def visit_grouping(self, grouping, asfrom=False, **kwargs):
    """Render the grouped element enclosed in parentheses.

    ``asfrom`` is accepted but is not passed on to the inner element.

    """
    inner_sql = grouping.element._compiler_dispatch(self, **kwargs)
    return "(" + inner_sql + ")"
2467
def visit_select_statement_grouping(self, grouping, **kwargs):
    """Render the grouped SELECT statement enclosed in parentheses."""
    inner_sql = grouping.element._compiler_dispatch(self, **kwargs)
    return "(" + inner_sql + ")"
2470
def visit_label_reference(
    self, element, within_columns_clause=False, **kwargs
):
    """Compile the element wrapped by a label reference.

    When a statement is being compiled and the dialect supports simple
    ORDER BY by label, the wrapped element's label is looked up in the
    current compile state; if it resolves to an element of the same
    lineage, ``render_label_as_label`` is passed along so that the label
    may be rendered by name.

    :raises CompileError: if the top stack entry has no compile state.

    """
    if self.stack and self.dialect.supports_simple_order_by_label:
        try:
            compile_state = cast(
                "Union[SelectState, CompoundSelectState]",
                self.stack[-1]["compile_state"],
            )
        except KeyError as ke:
            raise exc.CompileError(
                "Can't resolve label reference for ORDER BY / "
                "GROUP BY / DISTINCT etc."
            ) from ke

        _, only_froms, only_cols = compile_state._label_resolve_dict
        resolve_dict = only_froms if within_columns_clause else only_cols

        # this can be None in the case that a _label_reference()
        # were subject to a replacement operation, in which case
        # the replacement of the Label element may have changed
        # to something else like a ColumnClause expression.
        order_by_elem = element.element._order_by_label_element

        if order_by_elem is not None:
            label_name = order_by_elem.name
            if label_name in resolve_dict and order_by_elem.shares_lineage(
                resolve_dict[label_name]
            ):
                kwargs["render_label_as_label"] = (
                    element.element._order_by_label_element
                )

    return self.process(
        element.element,
        within_columns_clause=within_columns_clause,
        **kwargs,
    )
2517
def visit_textual_label_reference(
    self, element, within_columns_clause=False, **kwargs
):
    """Compile a label reference that was given as a plain string.

    Outside of any statement the element's text clause is compiled as
    is.  Otherwise the string is looked up in the label collections of
    the current compile state and the matching column is compiled with
    ``render_label_as_label`` set to it; a failed lookup is reported
    through ``coercions._no_text_coercion()``.

    """
    if not self.stack:
        # compiling the element outside of the context of a SELECT
        return self.process(element._text_clause)

    try:
        compile_state = cast(
            "Union[SelectState, CompoundSelectState]",
            self.stack[-1]["compile_state"],
        )
    except KeyError as ke:
        coercions._no_text_coercion(
            element.element,
            extra=(
                "Can't resolve label reference for ORDER BY / "
                "GROUP BY / DISTINCT etc."
            ),
            exc_cls=exc.CompileError,
            err=ke,
        )

    with_cols, only_froms, _ = compile_state._label_resolve_dict
    lookup = only_froms if within_columns_clause else with_cols
    try:
        col = lookup[element.element]
    except KeyError as err:
        coercions._no_text_coercion(
            element.element,
            extra=(
                "Can't resolve label reference for ORDER BY / "
                "GROUP BY / DISTINCT etc."
            ),
            exc_cls=exc.CompileError,
            err=err,
        )
    else:
        kwargs["render_label_as_label"] = col
        return self.process(
            col, within_columns_clause=within_columns_clause, **kwargs
        )
2562
def visit_label(
    self,
    label,
    add_to_result_map=None,
    within_label_clause=False,
    within_columns_clause=False,
    render_label_as_label=None,
    result_map_targets=(),
    **kw,
):
    """Compile a label in one of three forms.

    * ``<element> AS <name>`` when inside the columns clause and not
      already inside another label;
    * the label name alone when ``render_label_as_label`` is this label;
    * otherwise just the labeled element.

    """
    # only render labels within the columns clause
    # or ORDER BY clause of a select.  dialect-specific compilers
    # can modify this behavior.
    render_label_with_as = (
        within_columns_clause and not within_label_clause
    )
    render_label_only = render_label_as_label is label

    if not (render_label_only or render_label_with_as):
        return label.element._compiler_dispatch(
            self, within_columns_clause=False, **kw
        )

    if isinstance(label.name, elements._truncated_label):
        labelname = self._truncated_identifier("colident", label.name)
    else:
        labelname = label.name

    if not render_label_with_as:
        # here render_label_only is necessarily true
        return self.preparer.format_label(label, labelname)

    if add_to_result_map is not None:
        add_to_result_map(
            labelname,
            label.name,
            (label, labelname) + label._alt_names + result_map_targets,
            label.type,
        )

    element_sql = label.element._compiler_dispatch(
        self,
        within_columns_clause=True,
        within_label_clause=True,
        **kw,
    )
    as_keyword = OPERATORS[operators.as_]
    return element_sql + as_keyword + self.preparer.format_label(
        label, labelname
    )
2611
def _fallback_column_name(self, column):
    """Produce a name for a column whose ``name`` is None.

    This implementation has no fallback and always raises.

    :raises CompileError: unconditionally.

    """
    message = "Cannot compile Column object until its 'name' is assigned."
    raise exc.CompileError(message)
2616
def visit_lambda_element(self, element, **kw):
    """Compile the resolved SQL element of a lambda element."""
    return self.process(element._resolved, **kw)
2620
def visit_column(
    self,
    column: ColumnClause[Any],
    add_to_result_map: Optional[_ResultMapAppender] = None,
    include_table: bool = True,
    result_map_targets: Tuple[Any, ...] = (),
    ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
    **kwargs: Any,
) -> str:
    """Compile a column reference, qualified by table and schema where
    applicable.

    :param column: the column being rendered.
    :param add_to_result_map: optional callable that records the rendered
     name, original name, lookup targets and type of the column.
    :param include_table: when False the bare column name is returned.
    :param result_map_targets: extra targets appended to the result map
     entry.
    :param ambiguous_table_name_map: optional replacement names for
     tables; consulted only when no schema is in effect.
    :return: the rendered column name.

    """
    orig_name = column.name
    name = (
        orig_name
        if orig_name is not None
        else self._fallback_column_name(column)
    )

    is_literal = column.is_literal
    if not is_literal and isinstance(name, elements._truncated_label):
        name = self._truncated_identifier("colident", name)

    if add_to_result_map is not None:
        targets = (column, name, column.key) + result_map_targets
        tq_label = column._tq_label
        if tq_label:
            targets += (column._tq_label,)

        add_to_result_map(name, orig_name, targets, column.type)

    # note we are not currently accommodating for
    # literal_column(quoted_name('ident', True)) in the literal case
    name = (
        self.escape_literal_column(name)
        if is_literal
        else self.preparer.quote(name)
    )

    table = column.table
    if table is None or not include_table or not table.named_with_column:
        return name

    effective_schema = self.preparer.schema_for_object(table)
    schema_prefix = (
        self.preparer.quote_schema(effective_schema) + "."
        if effective_schema
        else ""
    )

    if TYPE_CHECKING:
        assert isinstance(table, NamedFromClause)
    tablename = table.name

    if (
        not effective_schema
        and ambiguous_table_name_map
        and tablename in ambiguous_table_name_map
    ):
        tablename = ambiguous_table_name_map[tablename]

    if isinstance(tablename, elements._truncated_label):
        tablename = self._truncated_identifier("alias", tablename)

    return schema_prefix + self.preparer.quote(tablename) + "." + name
2679
def visit_collation(self, element, **kw):
    """Render a collation name via the identifier preparer."""
    collation_name = element.collation
    return self.preparer.format_collation(collation_name)
2682
def visit_fromclause(self, fromclause, **kwargs):
    """Render a FROM clause as its ``name``, unmodified."""
    rendered = fromclause.name
    return rendered
2685
def visit_index(self, index, **kwargs):
    """Render an index as its ``name``, unmodified."""
    rendered = index.name
    return rendered
2688
def visit_typeclause(self, typeclause, **kw):
    """Render the type of a type clause using the dialect's type
    compiler.

    The type clause itself and this compiler's identifier preparer are
    handed to the type compiler as ``type_expression`` and
    ``identifier_preparer``.

    """
    kw.update(type_expression=typeclause, identifier_preparer=self.preparer)
    type_compiler = self.dialect.type_compiler_instance
    return type_compiler.process(typeclause.type, **kw)
2695
def post_process_text(self, text):
    """Return ``text`` with percent signs doubled when the identifier
    preparer requires it; otherwise return it unchanged.

    """
    return text.replace("%", "%%") if self.preparer._double_percents else text
2700
def escape_literal_column(self, text):
    """Return literal column ``text`` with percent signs doubled when the
    identifier preparer requires it; otherwise return it unchanged.

    """
    return text.replace("%", "%%") if self.preparer._double_percents else text
2705
def visit_textclause(self, textclause, add_to_result_map=None, **kw):
    """Compile a textual SQL clause.

    Bind parameter tokens matched by ``BIND_PARAMS`` are replaced with
    the compiled form of the corresponding bindparam of the text clause
    when there is one, or with a plain bind string for the name
    otherwise; escaped tokens matched by ``BIND_PARAMS_ESC`` are then
    un-escaped.

    """

    def do_bindparam(m):
        param_name = m.group(1)
        if param_name not in textclause._bindparams:
            return self.bindparam_string(param_name, **kw)
        return self.process(textclause._bindparams[param_name], **kw)

    if not self.stack:
        self.isplaintext = True

    if add_to_result_map:
        # text() object is present in the columns clause of a
        # select().   Add a no-name entry to the result map so that
        # row[text()] produces a result
        add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE)

    with_binds = BIND_PARAMS.sub(
        do_bindparam, self.post_process_text(textclause.text)
    )

    # un-escape any \:params
    return BIND_PARAMS_ESC.sub(lambda m: m.group(1), with_binds)
2730
def visit_textual_select(
    self, taf, compound_index=None, asfrom=False, **kw
):
    """Compile a textual SELECT: a text element accompanied by column
    arguments.

    A stack entry is pushed for the duration of the compilation.  When a
    result map is wanted at this level, the column arguments are compiled
    into it and the column matching flags of the compiler are set from
    ``taf.positional``.  Collected CTEs are rendered in front of the
    text.

    """
    toplevel = not self.stack
    entry = self._default_stack_entry if toplevel else self.stack[-1]

    self.stack.append(
        {
            "correlate_froms": set(),
            "asfrom_froms": set(),
            "selectable": taf,
        }
    )

    if taf._independent_ctes:
        self._dispatch_independent_ctes(taf, kw)

    if toplevel:
        populate_result_map = True
    elif compound_index == 0 and entry.get(
        "need_result_map_for_compound", False
    ):
        populate_result_map = True
    else:
        populate_result_map = entry.get("need_result_map_for_nested", False)

    if populate_result_map:
        self._ordered_columns = self._textual_ordered_columns = (
            taf.positional
        )

        # enable looser result column matching when the SQL text links to
        # Column objects by name only
        self._loose_column_name_matching = not taf.positional and bool(
            taf.column_args
        )

        for column_arg in taf.column_args:
            self.process(
                column_arg,
                within_columns_clause=True,
                add_to_result_map=self._add_to_result_map,
            )

    text = self.process(taf.element, **kw)
    if self.ctes:
        nesting_level = None if toplevel else len(self.stack)
        text = self._render_cte_clause(nesting_level=nesting_level) + text

    self.stack.pop(-1)

    return text
2782
def visit_null(self, expr: Null, **kw: Any) -> str:
    """Render the SQL NULL keyword."""
    null_keyword = "NULL"
    return null_keyword
2785
def visit_true(self, expr: True_, **kw: Any) -> str:
    """Render boolean true: ``true`` for dialects with native boolean
    support, ``1`` otherwise.

    """
    return "true" if self.dialect.supports_native_boolean else "1"
2791
def visit_false(self, expr: False_, **kw: Any) -> str:
    """Render boolean false: ``false`` for dialects with native boolean
    support, ``0`` otherwise.

    """
    return "false" if self.dialect.supports_native_boolean else "0"
2797
def _generate_delimited_list(self, elements, separator, **kw):
    """Compile each element and join the results with ``separator``.

    Elements that compile to an empty (falsy) string are left out.

    """
    rendered = (clause._compiler_dispatch(self, **kw) for clause in elements)
    return separator.join(filter(None, rendered))
2804
def _generate_delimited_and_list(self, clauses, **kw):
    """Compile ``clauses`` as a conjunction.

    The clauses are first passed through
    ``BooleanClauseList._process_clauses_for_boolean()`` for the AND
    operator.  A single resulting clause is compiled on its own;
    otherwise the non-empty compiled clauses are joined with the AND
    operator string.

    """
    lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean(
        operators.and_,
        elements.True_._singleton,
        elements.False_._singleton,
        clauses,
    )
    if lcc == 1:
        return clauses[0]._compiler_dispatch(self, **kw)

    rendered = (clause._compiler_dispatch(self, **kw) for clause in clauses)
    return OPERATORS[operators.and_].join(filter(None, rendered))
2821
def visit_tuple(self, clauselist, **kw):
    """Render a tuple: the clause list enclosed in parentheses."""
    members = self.visit_clauselist(clauselist, **kw)
    return f"({members})"
2824
def visit_element_list(self, element, **kw):
    """Render the clauses of the element separated by single spaces."""
    separator = " "
    return self._generate_delimited_list(element.clauses, separator, **kw)
2827
def visit_order_by_list(self, element, **kw):
    """Render the clauses of the element separated by ``", "``."""
    separator = ", "
    return self._generate_delimited_list(element.clauses, separator, **kw)
2830
def visit_clauselist(self, clauselist, **kw):
    """Render a clause list, joining its clauses with the string for its
    operator, or with a single space when it has no operator.

    """
    list_operator = clauselist.operator
    separator = " " if list_operator is None else OPERATORS[list_operator]
    return self._generate_delimited_list(clauselist.clauses, separator, **kw)
2839
def visit_expression_clauselist(self, clauselist, **kw):
    """Render a clause list that forms an operator expression.

    An operator-specific ``visit_<operator>_expression_clauselist``
    method takes precedence when this compiler has one; otherwise the
    clauses are joined with the operator's string, with
    ``_in_operator_expression`` set for the nested compilation.

    :raises UnsupportedCompilationError: if the operator has no string.

    """
    operator_ = clauselist.operator

    custom_visit = self._get_operator_dispatch(
        operator_, "expression_clauselist", None
    )
    if custom_visit:
        return custom_visit(clauselist, operator_, **kw)

    try:
        opstring = OPERATORS[operator_]
    except KeyError as err:
        raise exc.UnsupportedCompilationError(self, operator_) from err

    kw["_in_operator_expression"] = True
    return self._generate_delimited_list(clauselist.clauses, opstring, **kw)
2858
def visit_case(self, clause, **kwargs):
    """Render a CASE expression.

    The optional value, each WHEN / THEN pair and the optional ELSE are
    compiled in that order.

    """
    pieces = ["CASE "]

    if clause.value is not None:
        pieces.append(clause.value._compiler_dispatch(self, **kwargs) + " ")

    for cond, result in clause.whens:
        pieces.append(
            "WHEN "
            + cond._compiler_dispatch(self, **kwargs)
            + " THEN "
            + result._compiler_dispatch(self, **kwargs)
            + " "
        )

    if clause.else_ is not None:
        pieces.append(
            "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " "
        )

    pieces.append("END")
    return "".join(pieces)
2877
def visit_type_coerce(self, type_coerce, **kw):
    """Render a type coercion as its typed expression only."""
    typed_expression = type_coerce.typed_expression
    return typed_expression._compiler_dispatch(self, **kw)
2880
def visit_cast(self, cast, **kwargs):
    """Render ``CAST(<expression> AS <type>)``.

    When the compiled type contains a `` COLLATE ...`` portion, that
    portion is moved outside of the closing parenthesis of the CAST.

    """
    type_clause = cast.typeclause._compiler_dispatch(self, **kwargs)
    match = re.match("(.*)( COLLATE .*)", type_clause)
    if match:
        type_text, collation = match.group(1), match.group(2)
    else:
        type_text, collation = type_clause, ""

    expression = cast.clause._compiler_dispatch(self, **kwargs)
    return f"CAST({expression} AS {type_text}){collation}"
2889
def visit_frame_clause(self, frameclause, **kw):
    """Render the ``<lower bound> AND <upper bound>`` portion of a window
    frame.

    Each bound is ``UNBOUNDED PRECEDING`` / ``UNBOUNDED FOLLOWING``,
    ``CURRENT ROW``, or the compiled bound value followed by
    ``PRECEDING`` or ``FOLLOWING``.

    """
    frame_type = elements.FrameClauseType

    lower_type = frameclause.lower_type
    if lower_type is frame_type.UNBOUNDED:
        left = "UNBOUNDED PRECEDING"
    elif lower_type is frame_type.CURRENT:
        left = "CURRENT ROW"
    else:
        val = self.process(frameclause.lower_bind, **kw)
        direction = (
            "PRECEDING" if lower_type is frame_type.PRECEDING else "FOLLOWING"
        )
        left = f"{val} {direction}"

    upper_type = frameclause.upper_type
    if upper_type is frame_type.UNBOUNDED:
        right = "UNBOUNDED FOLLOWING"
    elif upper_type is frame_type.CURRENT:
        right = "CURRENT ROW"
    else:
        val = self.process(frameclause.upper_bind, **kw)
        direction = (
            "PRECEDING" if upper_type is frame_type.PRECEDING else "FOLLOWING"
        )
        right = f"{val} {direction}"

    return f"{left} AND {right}"
2915
def visit_over(self, over, **kwargs):
    """Render ``<element> OVER (...)``.

    The parenthesized part holds, in order and separated by spaces: the
    PARTITION BY clause and the ORDER BY clause (each only when present
    and non-empty), followed by at most one frame specification taken
    from ``range_``, ``rows`` or ``groups``, in that order of precedence.

    """
    text = over.element._compiler_dispatch(self, **kwargs)

    if over.range_ is not None:
        frame_keyword, frame = "RANGE", over.range_
    elif over.rows is not None:
        frame_keyword, frame = "ROWS", over.rows
    elif over.groups is not None:
        frame_keyword, frame = "GROUPS", over.groups
    else:
        frame_keyword = frame = None

    if frame is not None:
        range_ = f"{frame_keyword} BETWEEN {self.process(frame, **kwargs)}"
    else:
        range_ = None

    window_parts = []
    for word, clause in (
        ("PARTITION", over.partition_by),
        ("ORDER", over.order_by),
    ):
        if clause is not None and len(clause):
            window_parts.append(
                "%s BY %s" % (word, clause._compiler_dispatch(self, **kwargs))
            )
    if range_:
        window_parts.append(range_)

    return "%s OVER (%s)" % (text, " ".join(window_parts))
2942
def visit_withingroup(self, withingroup, **kwargs):
    """Render ``<element> WITHIN GROUP (ORDER BY <order_by>)``."""
    element_sql = withingroup.element._compiler_dispatch(self, **kwargs)
    order_by_sql = withingroup.order_by._compiler_dispatch(self, **kwargs)
    return f"{element_sql} WITHIN GROUP (ORDER BY {order_by_sql})"
2948
def visit_funcfilter(self, funcfilter, **kwargs):
    """Render ``<function> FILTER (WHERE <criterion>)``."""
    func_sql = funcfilter.func._compiler_dispatch(self, **kwargs)
    criterion_sql = funcfilter.criterion._compiler_dispatch(self, **kwargs)
    return f"{func_sql} FILTER (WHERE {criterion_sql})"
2954
def visit_aggregateorderby(self, aggregateorderby, **kwargs):
    """Render an aggregate function that carries an ORDER BY, according
    to the dialect's ``aggregate_order_by_style``.

    * ``NONE``: not supported; an error is raised.
    * ``INLINE``: a clone of the function is compiled with the ORDER BY
      folded into its argument list.
    * anything else: rendered in the WITHIN GROUP form.

    :raises CompileError: if the dialect style is ``NONE``.

    """
    style = self.dialect.aggregate_order_by_style

    if style is AggregateOrderByStyle.NONE:
        raise exc.CompileError(
            "this dialect does not support "
            "ORDER BY within an aggregate function"
        )

    if style is not AggregateOrderByStyle.INLINE:
        return self.visit_withingroup(aggregateorderby, **kwargs)

    # work on a clone so that the original function is left untouched
    new_fn = aggregateorderby.element._clone()
    inline_args = aggregate_orderby_inline(
        new_fn.clause_expr.element, aggregateorderby.order_by
    )
    new_fn.clause_expr = elements.Grouping(inline_args)

    return new_fn._compiler_dispatch(self, **kwargs)
2975
def visit_aggregate_orderby_inline(self, element, **kw):
    """Render ``<arguments> ORDER BY <ordering>`` for use inside the
    parentheses of an aggregate function.

    """
    arguments_sql = self.process(element.element, **kw)
    ordering_sql = self.process(element.aggregate_order_by, **kw)
    return f"{arguments_sql} ORDER BY {ordering_sql}"
2981
def visit_aggregate_strings_func(self, fn, *, use_function_name, **kw):
    """Render a string aggregation call under the given function name.

    The first two clauses of the function are the expression and the
    delimiter; the delimiter is compiled with ``literal_execute=True``.
    When the clauses carry an inline ORDER BY and the dialect renders
    aggregate ORDER BY inline, ``ORDER BY ...`` is appended inside the
    parentheses.

    """
    # aggregate_order_by attribute is present if visit_function
    # gave us a Function with aggregate_orderby_inline() as the inner
    # contents
    order_by = getattr(fn.clauses, "aggregate_order_by", None)

    literal_exec = {**kw, "literal_execute": True}

    # break up the function into its components so we can apply
    # literal_execute to the second argument (the delimiter)
    expr, delimiter = list(fn.clauses)[0:2]

    render_order_by = (
        order_by is not None
        and self.dialect.aggregate_order_by_style
        is AggregateOrderByStyle.INLINE
    )

    call_sql = (
        f"{use_function_name}({expr._compiler_dispatch(self, **kw)}, "
        f"{delimiter._compiler_dispatch(self, **literal_exec)}"
    )
    if render_order_by:
        return (
            call_sql
            + f" ORDER BY {order_by._compiler_dispatch(self, **kw)})"
        )
    return call_sql + ")"
3010
def visit_extract(self, extract, **kwargs):
    """Render ``EXTRACT(<field> FROM <expression>)``.

    The field name is translated through ``self.extract_map``; names not
    present in the map are used as given.

    """
    field = self.extract_map.get(extract.field, extract.field)
    expression_sql = extract.expr._compiler_dispatch(self, **kwargs)
    return f"EXTRACT({field} FROM {expression_sql})"
3017
def visit_scalar_function_column(self, element, **kw):
    """Render a column accessed from a scalar function result, in the
    form ``(<function>).<column>``.

    """
    function_sql = self.visit_function(element.fn, **kw)
    column_sql = self.visit_column(element, **kw)
    return f"({function_sql}).{column_sql}"
3022
def visit_function(
    self,
    func: Function[Any],
    add_to_result_map: Optional[_ResultMapAppender] = None,
    **kwargs: Any,
) -> str:
    """Compile a SQL function call.

    A ``visit_<lowercased name>_func`` method on this compiler takes
    precedence when present.  Otherwise the name comes from the
    ``FUNCTIONS`` table for known function classes, or from the
    function's own name (quoted when needed), prefixed by its package
    names and followed by the compiled argument list.

    :param func: the function construct.
    :param add_to_result_map: optional callable that records the function
     name and type for result processing.
    :return: the SQL string, with `` WITH ORDINALITY`` appended when the
     function requests it.

    """
    if add_to_result_map is not None:
        add_to_result_map(func.name, func.name, (func.name,), func.type)

    disp = getattr(self, f"visit_{func.name.lower()}_func", None)

    text: str

    if disp:
        text = disp(func, **kwargs)
    else:
        name = FUNCTIONS.get(func._deannotate().__class__, None)
        if name:
            if func._has_args:
                name += "%(expr)s"
        else:
            name = func.name
            if self.preparer._requires_quotes_illegal_chars(
                name
            ) or isinstance(name, elements.quoted_name):
                name = self.preparer.quote(name)
            name = name + "%(expr)s"

        tokens = []
        for tok in func.packagenames:
            # NOTE(review): the second test looks at ``name`` (the
            # function name template), not at ``tok``; preserved as found
            # - confirm whether ``tok`` was intended
            if self.preparer._requires_quotes_illegal_chars(
                tok
            ) or isinstance(name, elements.quoted_name):
                tokens.append(self.preparer.quote(tok))
            else:
                tokens.append(tok)
        tokens.append(name)

        text = ".".join(tokens) % {
            "expr": self.function_argspec(func, **kwargs)
        }

    if func._with_ordinality:
        text += " WITH ORDINALITY"
    return text
3068
def visit_next_value_func(self, next_value, **kw):
    """Render a "next value" function by rendering its sequence.

    The keyword arguments are not passed on to ``visit_sequence()``.

    """
    sequence = next_value.sequence
    return self.visit_sequence(sequence)
3071
def visit_sequence(self, sequence, **kw):
    """Render a sequence increment.

    This implementation does not support sequences and always raises.

    :raises NotImplementedError: unconditionally, naming the dialect.

    """
    message = (
        "Dialect '%s' does not support sequence increments."
        % self.dialect.name
    )
    raise NotImplementedError(message)
3077
def function_argspec(self, func: Function[Any], **kwargs: Any) -> str:
    """Compile the argument clause (``clause_expr``) of a function."""
    argument_clause = func.clause_expr
    return argument_clause._compiler_dispatch(self, **kwargs)
3080
def visit_compound_select(
    self, cs, asfrom=False, compound_index=None, **kwargs
):
    """Compile a compound SELECT into a string.

    Each member of ``cs.selects`` is compiled and the results are joined
    with the keyword looked up in ``self.compound_keywords``.  The GROUP
    BY, ORDER BY and row-limiting clauses of the compound statement
    itself are then appended, and any collected CTEs are rendered in
    front.

    :param cs: the compound select construct being compiled.
    :param asfrom: passed on to each member SELECT and to
     ``group_by_clause()``.
    :param compound_index: index of this construct within an enclosing
     compound select, or ``None``.
    :return: the SQL string.

    """
    toplevel = not self.stack

    compile_state = cs._compile_state_factory(cs, self, **kwargs)

    # the first compile state produced at the top level becomes the
    # compile state of the compiler as a whole
    if toplevel and not self.compile_state:
        self.compile_state = compile_state

    compound_stmt = compile_state.statement

    entry = self._default_stack_entry if toplevel else self.stack[-1]
    # a result map is needed at the top level, or for the first / only
    # element (compound_index of 0 or None) of an enclosing construct
    # that itself needs one
    need_result_map = toplevel or (
        not compound_index
        and entry.get("need_result_map_for_compound", False)
    )

    # indicates there is already a CompoundSelect in play
    if compound_index == 0:
        entry["select_0"] = cs

    # the new stack entry shares the correlation collections of the
    # enclosing entry
    self.stack.append(
        {
            "correlate_froms": entry["correlate_froms"],
            "asfrom_froms": entry["asfrom_froms"],
            "selectable": cs,
            "compile_state": compile_state,
            "need_result_map_for_compound": need_result_map,
        }
    )

    if compound_stmt._independent_ctes:
        self._dispatch_independent_ctes(compound_stmt, kwargs)

    keyword = self.compound_keywords[cs.keyword]

    # each member receives its position as compound_index
    text = (" " + keyword + " ").join(
        (
            c._compiler_dispatch(
                self, asfrom=asfrom, compound_index=i, **kwargs
            )
            for i, c in enumerate(cs.selects)
        )
    )

    # the trailing clauses are compiled with include_table=False; this
    # is set only after the member selects have been compiled
    kwargs["include_table"] = False
    text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs))
    text += self.order_by_clause(cs, **kwargs)
    if cs._has_row_limiting_clause:
        text += self._row_limit_clause(cs, **kwargs)

    if self.ctes:
        # the nesting level is taken while this statement's entry is
        # still on the stack
        nesting_level = len(self.stack) if not toplevel else None
        text = (
            self._render_cte_clause(
                nesting_level=nesting_level,
                include_following_stack=True,
            )
            + text
        )

    self.stack.pop(-1)
    return text
3145
def _row_limit_clause(self, cs, **kwargs):
    """Render the row-limiting clause for ``cs``.

    Uses ``fetch_clause()`` when ``cs._fetch_clause`` is not ``None``
    and ``limit_clause()`` otherwise.
    """
    if cs._fetch_clause is None:
        return self.limit_clause(cs, **kwargs)
    return self.fetch_clause(cs, **kwargs)
3151
def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
    """Look up a ``visit_<opname>_<qualifier1>[_<qualifier2>]`` method.

    ``<opname>`` is ``operator_.__name__``; the second qualifier is
    left out of the name when it is falsy.

    :return: the bound attribute, or ``None`` if this compiler has no
     attribute of that name.
    """
    suffix = "_" + qualifier2 if qualifier2 else ""
    method_name = "visit_%s_%s%s" % (operator_.__name__, qualifier1, suffix)
    return getattr(self, method_name, None)
3159
def _get_custom_operator_dispatch(self, operator_, qualifier1):
    """Look up a ``visit_<visit_name>_op_<qualifier1>`` method.

    ``<visit_name>`` is taken from ``operator_.visit_name``.

    :return: the bound attribute, or ``None`` if this compiler has no
     attribute of that name.
    """
    method_name = "visit_%s_op_%s" % (operator_.visit_name, qualifier1)
    return getattr(self, method_name, None)
3163
def visit_unary(
    self, unary, add_to_result_map=None, result_map_targets=(), **kw
):
    """Render a unary expression that has an operator or a modifier.

    A dialect-specific ``visit_<op>_unary_operator`` /
    ``visit_<op>_unary_modifier`` method is used when one exists;
    otherwise the string from ``OPERATORS`` is placed before (operator)
    or after (modifier) the rendered element.

    When ``add_to_result_map`` is given, ``unary`` is appended to
    ``result_map_targets`` and both are forwarded to the element.

    :raises CompileError: if the expression has both an operator and a
     modifier, or neither.
    """
    if add_to_result_map is not None:
        result_map_targets += (unary,)
        kw["add_to_result_map"] = add_to_result_map
        kw["result_map_targets"] = result_map_targets

    if unary.operator:
        if unary.modifier:
            raise exc.CompileError(
                "Unary expression does not support operator "
                "and modifier simultaneously"
            )
        op, role = unary.operator, "operator"
        generic = self._generate_generic_unary_operator
    elif unary.modifier:
        op, role = unary.modifier, "modifier"
        generic = self._generate_generic_unary_modifier
    else:
        raise exc.CompileError(
            "Unary expression has no operator or modifier"
        )

    specific = self._get_operator_dispatch(op, "unary", role)
    if specific:
        return specific(unary, op, **kw)
    return generic(unary, OPERATORS[op], **kw)
3201
def visit_truediv_binary(self, binary, operator, **kw):
    """Render true division as ``<left> / <right>``.

    When ``self.dialect.div_is_floordiv`` is set, the right operand is
    wrapped in ``elements.Cast``: its own type is reused if that type's
    affinity is ``Numeric`` or ``Float``, otherwise a plain
    ``Numeric()`` is used.
    """
    needs_cast = self.dialect.div_is_floordiv
    numerator = self.process(binary.left, **kw)

    if not needs_cast:
        return numerator + " / " + self.process(binary.right, **kw)

    # TODO: would need a fast cast again here,
    # unless we want to use an implicit cast like "+ 0.0"
    divisor = binary.right
    if divisor.type._type_affinity in (sqltypes.Numeric, sqltypes.Float):
        cast_type = divisor.type
    else:
        cast_type = sqltypes.Numeric()
    casted = elements.Cast(divisor, cast_type)
    return numerator + " / " + self.process(casted, **kw)
3228
def visit_floordiv_binary(self, binary, operator, **kw):
    """Render floor division.

    A bare ``<left> / <right>`` is emitted only when
    ``self.dialect.div_is_floordiv`` is set and the right operand's
    type affinity is ``Integer``; in every other case the quotient is
    wrapped in ``FLOOR(...)``.
    """
    slash_already_floors = (
        self.dialect.div_is_floordiv
        and binary.right.type._type_affinity is sqltypes.Integer
    )
    numerator = self.process(binary.left, **kw)
    quotient = numerator + " / " + self.process(binary.right, **kw)
    if slash_already_floors:
        return quotient
    return "FLOOR(%s)" % quotient
3245
def visit_is_true_unary_operator(self, element, operator, **kw):
    """Render an "is true" test of ``element.element``.

    The inner expression is rendered as-is when the element is
    implicitly boolean or the dialect supports native booleans;
    otherwise it is compared with ``= 1``.
    """
    renders_bare = (
        element._is_implicitly_boolean
        or self.dialect.supports_native_boolean
    )
    inner = self.process(element.element, **kw)
    if renders_bare:
        return inner
    return "%s = 1" % inner
3254
def visit_is_false_unary_operator(self, element, operator, **kw):
    """Render an "is false" test of ``element.element``.

    ``NOT <expr>`` is used when the element is implicitly boolean or
    the dialect supports native booleans; otherwise the expression is
    compared with ``= 0``.
    """
    can_negate = (
        element._is_implicitly_boolean
        or self.dialect.supports_native_boolean
    )
    inner = self.process(element.element, **kw)
    template = "NOT %s" if can_negate else "%s = 0"
    return template % inner
3263
def visit_not_match_op_binary(self, binary, operator, **kw):
    """Render a negated MATCH as ``NOT <match expression>``.

    The binary is re-dispatched through ``visit_binary()`` with the
    operator overridden to ``operators.match_op``.

    The compile-time keyword arguments in ``kw`` are forwarded so that
    options such as ``literal_binds`` reach the operands, as they do
    for the other binary visitors; previously they were silently
    dropped here.
    """
    return "NOT %s" % self.visit_binary(
        binary, override_operator=operators.match_op, **kw
    )
3268
def visit_not_in_op_binary(self, binary, operator, **kw):
    """Render NOT IN, always wrapped in parentheses."""
    # The empty-set form of NOT IN is "(col NOT IN (null) OR 1 = 1)";
    # because that form contains an OR, the whole expression has to be
    # bracketed.
    opstring = OPERATORS[operator]
    unwrapped = self._generate_generic_binary(binary, opstring, **kw)
    return "(%s)" % unwrapped
3276
def visit_empty_set_op_expr(self, type_, expand_op, **kw):
    """Render the replacement text for an empty IN / NOT IN list.

    For ``in_op`` and ``not_in_op`` the result is a NULL (or a
    parenthesized row of one NULL per entry in ``type_`` when there is
    more than one) followed by a trailing ``AND (1 != 1`` / ``OR (1 = 1``
    fragment.  The parentheses in the returned text are deliberately
    unbalanced; presumably the surrounding template supplies the
    matching ones.

    Any other operator is handed to ``visit_empty_set_expr()``.
    """
    if expand_op is operators.not_in_op:
        multi_form, single_form = "(%s)) OR (1 = 1", "NULL) OR (1 = 1"
    elif expand_op is operators.in_op:
        multi_form, single_form = "(%s)) AND (1 != 1", "NULL) AND (1 != 1"
    else:
        return self.visit_empty_set_expr(type_)

    if len(type_) > 1:
        return multi_form % (", ".join("NULL" for _ in type_))
    return single_form
3294
def visit_empty_set_expr(self, element_types, **kw):
    """Reject rendering of an empty set expression in this compiler.

    This implementation never returns; the error message names
    ``self.dialect.name``.

    :raises NotImplementedError: unconditionally.
    """
    message = (
        "Dialect '%s' does not support empty set expression."
        % self.dialect.name
    )
    raise NotImplementedError(message)
3300
def _literal_execute_expanding_parameter_literal_binds(
    self, parameter, values, bind_expression_template=None
):
    """Render the values of an expanding parameter inline as literals.

    :param parameter: the expanding bind parameter; its ``type`` and
     ``expand_op`` are consulted.
    :param values: the sequence of values to render; may be empty.
    :param bind_expression_template: optional string carrying the text
     to place around each rendered value; only consulted for
     non-empty, non-tuple values.
    :return: a 2-tuple whose first element is always an empty tuple and
     whose second element is the rendered replacement string.
    :raises NotImplementedError: for tuple values whose type defines a
     bind expression.
    """
    typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)

    if not values:
        # empty IN expression. note we don't need to use
        # bind_expression_template here because there are no
        # expressions to render.

        if typ_dialect_impl._is_tuple_type:
            replacement_expression = (
                "VALUES " if self.dialect.tuple_in_values else ""
            ) + self.visit_empty_set_op_expr(
                parameter.type.types, parameter.expand_op
            )

        else:
            replacement_expression = self.visit_empty_set_op_expr(
                [parameter.type], parameter.expand_op
            )

    # tuple handling applies to an explicit tuple type, or to a null
    # type whose first value is a non-string sequence
    elif typ_dialect_impl._is_tuple_type or (
        typ_dialect_impl._isnull
        and isinstance(values[0], collections_abc.Sequence)
        and not isinstance(values[0], (str, bytes))
    ):
        if typ_dialect_impl._has_bind_expression:
            raise NotImplementedError(
                "bind_expression() on TupleType not supported with "
                "literal_binds"
            )

        # one parenthesized group per tuple; each member is rendered
        # with the type at the same position in parameter.type.types
        replacement_expression = (
            "VALUES " if self.dialect.tuple_in_values else ""
        ) + ", ".join(
            "(%s)"
            % (
                ", ".join(
                    self.render_literal_value(value, param_type)
                    for value, param_type in zip(
                        tuple_element, parameter.type.types
                    )
                )
            )
            for i, tuple_element in enumerate(values)
        )
    else:
        if bind_expression_template:
            post_compile_pattern = self._post_compile_pattern
            m = post_compile_pattern.search(bind_expression_template)
            assert m and m.group(
                2
            ), "unexpected format for expanding parameter"

            # NOTE(review): group(2) is presumably the
            # "~~<left>~~REPL~~<right>~~" portion that visit_bindparam()
            # embeds in the POSTCOMPILE token; confirm against
            # _post_compile_pattern.  Elements 1 and 3 of the split are
            # placed before / after every rendered value.
            tok = m.group(2).split("~~")
            be_left, be_right = tok[1], tok[3]
            replacement_expression = ", ".join(
                "%s%s%s"
                % (
                    be_left,
                    self.render_literal_value(value, parameter.type),
                    be_right,
                )
                for value in values
            )
        else:
            replacement_expression = ", ".join(
                self.render_literal_value(value, parameter.type)
                for value in values
            )

    return (), replacement_expression
3374
def _literal_execute_expanding_parameter(self, name, parameter, values):
    """Expand an expanding bound parameter into one bind per value.

    When ``parameter.literal_execute`` is set this defers entirely to
    ``_literal_execute_expanding_parameter_literal_binds()``.

    :param name: base name from which the per-value bind names are
     derived.
    :param parameter: the expanding bind parameter.
    :param values: the sequence of values; may be empty.
    :return: a 2-tuple ``(to_update, replacement_expression)`` where
     ``to_update`` is a list of ``(bind_name, value)`` pairs (empty for
     empty ``values``) and ``replacement_expression`` is the string of
     rendered bind templates.
    """
    if parameter.literal_execute:
        return self._literal_execute_expanding_parameter_literal_binds(
            parameter, values
        )

    dialect = self.dialect
    typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect)

    if self._numeric_binds:
        bind_template = self.compilation_bindtemplate
    else:
        bind_template = self.bindtemplate

    # choose once whether each rendered bind is passed through
    # render_bind_cast()
    if (
        self.dialect._bind_typing_render_casts
        and typ_dialect_impl.render_bind_cast
    ):

        def _render_bindtemplate(name):
            return self.render_bind_cast(
                parameter.type,
                typ_dialect_impl,
                bind_template % {"name": name},
            )

    else:

        def _render_bindtemplate(name):
            return bind_template % {"name": name}

    if not values:
        # empty IN: no new binds, only the empty-set replacement text
        to_update = []
        if typ_dialect_impl._is_tuple_type:
            replacement_expression = self.visit_empty_set_op_expr(
                parameter.type.types, parameter.expand_op
            )
        else:
            replacement_expression = self.visit_empty_set_op_expr(
                [parameter.type], parameter.expand_op
            )

    # tuple handling applies to an explicit tuple type, or to a null
    # type whose first value is a non-string sequence
    elif typ_dialect_impl._is_tuple_type or (
        typ_dialect_impl._isnull
        and isinstance(values[0], collections_abc.Sequence)
        and not isinstance(values[0], (str, bytes))
    ):
        assert not typ_dialect_impl._is_array
        # flattened list of binds named "<name>_<i>_<j>", both 1-based
        to_update = [
            ("%s_%s_%s" % (name, i, j), value)
            for i, tuple_element in enumerate(values, 1)
            for j, value in enumerate(tuple_element, 1)
        ]

        # here i and j are 0-based positions into the flattened list.
        # NOTE(review): the index arithmetic assumes every tuple in
        # ``values`` has the same length - confirm with callers.
        replacement_expression = (
            "VALUES " if dialect.tuple_in_values else ""
        ) + ", ".join(
            "(%s)"
            % (
                ", ".join(
                    _render_bindtemplate(
                        to_update[i * len(tuple_element) + j][0]
                    )
                    for j, value in enumerate(tuple_element)
                )
            )
            for i, tuple_element in enumerate(values)
        )
    else:
        # scalar values: binds named "<name>_<i>", 1-based
        to_update = [
            ("%s_%s" % (name, i), value)
            for i, value in enumerate(values, 1)
        ]
        replacement_expression = ", ".join(
            _render_bindtemplate(key) for key, value in to_update
        )

    return to_update, replacement_expression
3453
def visit_binary(
    self,
    binary,
    override_operator=None,
    eager_grouping=False,
    from_linter=None,
    lateral_from_linter=None,
    **kw,
):
    """Render a binary expression.

    The operator is ``override_operator`` if given, else
    ``binary.operator``.  A ``visit_<op>_binary`` method is used when
    this compiler has one; otherwise the string from ``OPERATORS`` is
    used with ``_generate_generic_binary()``.

    When ``from_linter`` is set and ``binary.operator`` is a comparison,
    the cartesian product of the FROM objects of both sides is recorded
    as edges, on ``lateral_from_linter`` (with the enclosing lateral
    added to both sides) if that is given, else on ``from_linter``.

    :raises UnsupportedCompilationError: if there is neither a specific
     visit method nor an ``OPERATORS`` entry for the operator.
    """
    if from_linter and operators.is_comparison(binary.operator):
        if lateral_from_linter is not None:
            linter = lateral_from_linter
            lateral = [kw["enclosing_lateral"]]
            left_froms = binary.left._from_objects + lateral
            right_froms = binary.right._from_objects + lateral
        else:
            linter = from_linter
            left_froms = binary.left._from_objects
            right_froms = binary.right._from_objects
        linter.edges.update(
            itertools.product(_de_clone(left_froms), _de_clone(right_froms))
        )

    # don't allow "? = ?" to render
    if (
        self.ansi_bind_rules
        and isinstance(binary.left, elements.BindParameter)
        and isinstance(binary.right, elements.BindParameter)
    ):
        kw["literal_execute"] = True

    operator_ = override_operator or binary.operator
    specific = self._get_operator_dispatch(operator_, "binary", None)
    if specific:
        return specific(binary, operator_, **kw)

    try:
        opstring = OPERATORS[operator_]
    except KeyError as err:
        raise exc.UnsupportedCompilationError(self, operator_) from err

    return self._generate_generic_binary(
        binary,
        opstring,
        from_linter=from_linter,
        lateral_from_linter=lateral_from_linter,
        **kw,
    )
3509
def visit_function_as_comparison_op_binary(self, element, operator, **kw):
    """Render the comparison by processing ``element.sql_function``."""
    sql_function = element.sql_function
    return self.process(sql_function, **kw)
3512
def visit_mod_binary(self, binary, operator, **kw):
    """Render the modulo operator.

    The percent sign is doubled when ``self.preparer._double_percents``
    is set.
    """
    opstring = " %% " if self.preparer._double_percents else " % "
    left_sql = self.process(binary.left, **kw)
    return left_sql + opstring + self.process(binary.right, **kw)
3526
def visit_custom_op_binary(self, element, operator, **kw):
    """Render a custom operator used as a binary operator.

    If ``operator.visit_name`` is set and this compiler has a matching
    ``visit_<visit_name>_op_binary`` method, that method renders the
    expression.  Otherwise the escaped ``operator.opstring`` is placed
    between the operands, with ``operator.eager_grouping`` applied.
    """
    specific = (
        self._get_custom_operator_dispatch(operator, "binary")
        if operator.visit_name
        else None
    )
    if specific:
        return specific(element, operator, **kw)

    kw["eager_grouping"] = operator.eager_grouping
    opstring = " " + self.escape_literal_column(operator.opstring) + " "
    return self._generate_generic_binary(element, opstring, **kw)
3539
def visit_custom_op_unary_operator(self, element, operator, **kw):
    """Render a custom operator used as a prefix unary operator.

    If ``operator.visit_name`` is set and this compiler has a matching
    ``visit_<visit_name>_op_unary`` method, that method renders the
    expression.  Otherwise the escaped ``operator.opstring`` plus a
    space is placed before the element.
    """
    specific = (
        self._get_custom_operator_dispatch(operator, "unary")
        if operator.visit_name
        else None
    )
    if specific:
        return specific(element, operator, **kw)

    prefix = self.escape_literal_column(operator.opstring) + " "
    return self._generate_generic_unary_operator(element, prefix, **kw)
3549
def visit_custom_op_unary_modifier(self, element, operator, **kw):
    """Render a custom operator used as a postfix unary modifier.

    If ``operator.visit_name`` is set and this compiler has a matching
    ``visit_<visit_name>_op_unary`` method, that method renders the
    expression.  Otherwise a space plus the escaped
    ``operator.opstring`` is placed after the element.
    """
    specific = (
        self._get_custom_operator_dispatch(operator, "unary")
        if operator.visit_name
        else None
    )
    if specific:
        return specific(element, operator, **kw)

    suffix = " " + self.escape_literal_column(operator.opstring)
    return self._generate_generic_unary_modifier(element, suffix, **kw)
3559
def _generate_generic_binary(
    self,
    binary: BinaryExpression[Any],
    opstring: str,
    eager_grouping: bool = False,
    **kw: Any,
) -> str:
    """Render ``<left><opstring><right>`` for ``binary``.

    Both operands are dispatched with ``_in_operator_expression=True``,
    ``_binary_op`` set to ``binary.operator`` and the given
    ``eager_grouping``.  The result is parenthesized only when this
    call was itself made inside an operator expression (the incoming
    ``_in_operator_expression`` flag) and ``eager_grouping`` is true.
    """
    nested_in_operator = kw.get("_in_operator_expression", False)

    operand_kw = dict(
        kw,
        _in_operator_expression=True,
        _binary_op=binary.operator,
        eager_grouping=eager_grouping,
    )
    left_sql = binary.left._compiler_dispatch(self, **operand_kw)
    right_sql = binary.right._compiler_dispatch(self, **operand_kw)
    text = left_sql + opstring + right_sql

    if nested_in_operator and eager_grouping:
        return "(%s)" % text
    return text
3584
def _generate_generic_unary_operator(self, unary, opstring, **kw):
    """Render ``opstring`` followed by the compiled ``unary.element``."""
    operand_sql = unary.element._compiler_dispatch(self, **kw)
    return opstring + operand_sql
3587
def _generate_generic_unary_modifier(self, unary, opstring, **kw):
    """Render the compiled ``unary.element`` followed by ``opstring``."""
    operand_sql = unary.element._compiler_dispatch(self, **kw)
    return operand_sql + opstring
3590
@util.memoized_property
def _like_percent_literal(self):
    """The ``'%'`` literal column, typed as a string.

    Used by the contains / startswith / endswith visitors to build
    LIKE patterns.
    """
    percent = elements.literal_column("'%'", type_=sqltypes.STRINGTYPE)
    return percent
3594
def visit_ilike_case_insensitive_operand(self, element, **kw):
    """Render the wrapped operand inside ``lower(...)``."""
    operand_sql = element.element._compiler_dispatch(self, **kw)
    return f"lower({operand_sql})"
3597
def visit_contains_op_binary(self, binary, operator, **kw):
    """Render ``contains`` as LIKE with ``'%'`` on both sides.

    Works on a clone of ``binary`` whose right side becomes
    ``'%' || right || '%'``.
    """
    like_expr = binary._clone()
    wildcard = self._like_percent_literal
    pattern = wildcard.concat(like_expr.right)
    like_expr.right = pattern.concat(wildcard)
    return self.visit_like_op_binary(like_expr, operator, **kw)
3603
def visit_not_contains_op_binary(self, binary, operator, **kw):
    """Render negated ``contains`` as NOT LIKE with ``'%'`` on both sides.

    Works on a clone of ``binary`` whose right side becomes
    ``'%' || right || '%'``.
    """
    like_expr = binary._clone()
    wildcard = self._like_percent_literal
    pattern = wildcard.concat(like_expr.right)
    like_expr.right = pattern.concat(wildcard)
    return self.visit_not_like_op_binary(like_expr, operator, **kw)
3609
def visit_icontains_op_binary(self, binary, operator, **kw):
    """Render case-insensitive ``contains``.

    Works on a clone of ``binary``: both sides are wrapped with
    ``ilike_case_insensitive()`` and the right side is additionally
    surrounded with ``'%'`` before being handed to
    ``visit_ilike_op_binary()``.
    """
    like_expr = binary._clone()
    wildcard = self._like_percent_literal
    like_expr.left = ilike_case_insensitive(like_expr.left)
    folded_right = ilike_case_insensitive(like_expr.right)
    like_expr.right = wildcard.concat(folded_right).concat(wildcard)
    return self.visit_ilike_op_binary(like_expr, operator, **kw)
3618
def visit_not_icontains_op_binary(self, binary, operator, **kw):
    """Render negated case-insensitive ``contains``.

    Works on a clone of ``binary``: both sides are wrapped with
    ``ilike_case_insensitive()`` and the right side is additionally
    surrounded with ``'%'`` before being handed to
    ``visit_not_ilike_op_binary()``.
    """
    like_expr = binary._clone()
    wildcard = self._like_percent_literal
    like_expr.left = ilike_case_insensitive(like_expr.left)
    folded_right = ilike_case_insensitive(like_expr.right)
    like_expr.right = wildcard.concat(folded_right).concat(wildcard)
    return self.visit_not_ilike_op_binary(like_expr, operator, **kw)
3627
def visit_startswith_op_binary(self, binary, operator, **kw):
    """Render ``startswith`` as LIKE.

    Works on a clone of ``binary`` whose right side is combined with
    the ``'%'`` literal through ``percent._rconcat(right)``.
    """
    like_expr = binary._clone()
    wildcard = self._like_percent_literal
    like_expr.right = wildcard._rconcat(like_expr.right)
    return self.visit_like_op_binary(like_expr, operator, **kw)
3633
def visit_not_startswith_op_binary(self, binary, operator, **kw):
    """Render negated ``startswith`` as NOT LIKE.

    Works on a clone of ``binary`` whose right side is combined with
    the ``'%'`` literal through ``percent._rconcat(right)``.
    """
    like_expr = binary._clone()
    wildcard = self._like_percent_literal
    like_expr.right = wildcard._rconcat(like_expr.right)
    return self.visit_not_like_op_binary(like_expr, operator, **kw)
3639
def visit_istartswith_op_binary(self, binary, operator, **kw):
    """Render case-insensitive ``startswith``.

    Works on a clone of ``binary``: both sides are wrapped with
    ``ilike_case_insensitive()`` and the right side is combined with
    the ``'%'`` literal through ``percent._rconcat(...)`` before being
    handed to ``visit_ilike_op_binary()``.
    """
    like_expr = binary._clone()
    wildcard = self._like_percent_literal
    like_expr.left = ilike_case_insensitive(like_expr.left)
    folded_right = ilike_case_insensitive(like_expr.right)
    like_expr.right = wildcard._rconcat(folded_right)
    return self.visit_ilike_op_binary(like_expr, operator, **kw)
3646
def visit_not_istartswith_op_binary(self, binary, operator, **kw):
    """Render negated case-insensitive ``startswith``.

    Works on a clone of ``binary``: both sides are wrapped with
    ``ilike_case_insensitive()`` and the right side is combined with
    the ``'%'`` literal through ``percent._rconcat(...)`` before being
    handed to ``visit_not_ilike_op_binary()``.
    """
    like_expr = binary._clone()
    wildcard = self._like_percent_literal
    like_expr.left = ilike_case_insensitive(like_expr.left)
    folded_right = ilike_case_insensitive(like_expr.right)
    like_expr.right = wildcard._rconcat(folded_right)
    return self.visit_not_ilike_op_binary(like_expr, operator, **kw)
3653
def visit_endswith_op_binary(self, binary, operator, **kw):
    """Render ``endswith`` as LIKE.

    Works on a clone of ``binary`` whose right side becomes
    ``'%' || right``.
    """
    like_expr = binary._clone()
    wildcard = self._like_percent_literal
    like_expr.right = wildcard.concat(like_expr.right)
    return self.visit_like_op_binary(like_expr, operator, **kw)
3659
def visit_not_endswith_op_binary(self, binary, operator, **kw):
    """Render negated ``endswith`` as NOT LIKE.

    Works on a clone of ``binary`` whose right side becomes
    ``'%' || right``.
    """
    like_expr = binary._clone()
    wildcard = self._like_percent_literal
    like_expr.right = wildcard.concat(like_expr.right)
    return self.visit_not_like_op_binary(like_expr, operator, **kw)
3665
def visit_iendswith_op_binary(self, binary, operator, **kw):
    """Render case-insensitive ``endswith``.

    Works on a clone of ``binary``: both sides are wrapped with
    ``ilike_case_insensitive()`` and ``'%'`` is concatenated in front
    of the right side before being handed to
    ``visit_ilike_op_binary()``.
    """
    like_expr = binary._clone()
    wildcard = self._like_percent_literal
    like_expr.left = ilike_case_insensitive(like_expr.left)
    folded_right = ilike_case_insensitive(like_expr.right)
    like_expr.right = wildcard.concat(folded_right)
    return self.visit_ilike_op_binary(like_expr, operator, **kw)
3672
def visit_not_iendswith_op_binary(self, binary, operator, **kw):
    """Render negated case-insensitive ``endswith``.

    Works on a clone of ``binary``: both sides are wrapped with
    ``ilike_case_insensitive()`` and ``'%'`` is concatenated in front
    of the right side before being handed to
    ``visit_not_ilike_op_binary()``.
    """
    like_expr = binary._clone()
    wildcard = self._like_percent_literal
    like_expr.left = ilike_case_insensitive(like_expr.left)
    folded_right = ilike_case_insensitive(like_expr.right)
    like_expr.right = wildcard.concat(folded_right)
    return self.visit_not_ilike_op_binary(like_expr, operator, **kw)
3679
def visit_like_op_binary(self, binary, operator, **kw):
    """Render ``<left> LIKE <right>``.

    An ``ESCAPE <literal>`` suffix is added when ``binary.modifiers``
    carries a non-``None`` ``"escape"`` entry.
    """
    escape = binary.modifiers.get("escape", None)

    left_sql = binary.left._compiler_dispatch(self, **kw)
    right_sql = binary.right._compiler_dispatch(self, **kw)
    text = "%s LIKE %s" % (left_sql, right_sql)

    if escape is not None:
        text += " ESCAPE " + self.render_literal_value(
            escape, sqltypes.STRINGTYPE
        )
    return text
3691
def visit_not_like_op_binary(self, binary, operator, **kw):
    """Render ``<left> NOT LIKE <right>``.

    An ``ESCAPE <literal>`` suffix is added when ``binary.modifiers``
    carries a non-``None`` ``"escape"`` entry.
    """
    escape = binary.modifiers.get("escape", None)

    left_sql = binary.left._compiler_dispatch(self, **kw)
    right_sql = binary.right._compiler_dispatch(self, **kw)
    text = "%s NOT LIKE %s" % (left_sql, right_sql)

    if escape is not None:
        text += " ESCAPE " + self.render_literal_value(
            escape, sqltypes.STRINGTYPE
        )
    return text
3702
def visit_ilike_op_binary(self, binary, operator, **kw):
    """Render ILIKE by way of ``visit_like_op_binary()``.

    When called for ``operators.ilike_op`` itself, both sides of a
    clone of ``binary`` are wrapped with ``ilike_case_insensitive()``
    first.
    """
    if operator is operators.ilike_op:
        folded = binary._clone()
        folded.left = ilike_case_insensitive(folded.left)
        folded.right = ilike_case_insensitive(folded.right)
        binary = folded
    # for any other operator the caller (e.g. the icontains /
    # istartswith / iendswith visitors) is assumed to have applied
    # ilike_case_insensitive() already

    return self.visit_like_op_binary(binary, operator, **kw)
3711
def visit_not_ilike_op_binary(self, binary, operator, **kw):
    """Render NOT ILIKE by way of ``visit_not_like_op_binary()``.

    When called for ``operators.not_ilike_op`` itself, both sides of a
    clone of ``binary`` are wrapped with ``ilike_case_insensitive()``
    first.
    """
    if operator is operators.not_ilike_op:
        folded = binary._clone()
        folded.left = ilike_case_insensitive(folded.left)
        folded.right = ilike_case_insensitive(folded.right)
        binary = folded
    # for any other operator the caller (e.g. the not_icontains /
    # not_istartswith / not_iendswith visitors) is assumed to have
    # applied ilike_case_insensitive() already

    return self.visit_not_like_op_binary(binary, operator, **kw)
3720
def visit_between_op_binary(self, binary, operator, **kw):
    """Render BETWEEN, or BETWEEN SYMMETRIC.

    The SYMMETRIC form is used when ``binary.modifiers`` has a truthy
    ``"symmetric"`` entry.
    """
    if binary.modifiers.get("symmetric", False):
        opstring = " BETWEEN SYMMETRIC "
    else:
        opstring = " BETWEEN "
    return self._generate_generic_binary(binary, opstring, **kw)
3726
def visit_not_between_op_binary(self, binary, operator, **kw):
    """Render NOT BETWEEN, or NOT BETWEEN SYMMETRIC.

    The SYMMETRIC form is used when ``binary.modifiers`` has a truthy
    ``"symmetric"`` entry.
    """
    if binary.modifiers.get("symmetric", False):
        opstring = " NOT BETWEEN SYMMETRIC "
    else:
        opstring = " NOT BETWEEN "
    return self._generate_generic_binary(binary, opstring, **kw)
3734
def visit_regexp_match_op_binary(
    self, binary: BinaryExpression[Any], operator: Any, **kw: Any
) -> str:
    """Reject regular expression matching in this compiler.

    :raises CompileError: unconditionally; the message names
     ``self.dialect.name``.
    """
    dialect_name = self.dialect.name
    raise exc.CompileError(
        "%s dialect does not support regular expressions" % dialect_name
    )
3742
def visit_not_regexp_match_op_binary(
    self, binary: BinaryExpression[Any], operator: Any, **kw: Any
) -> str:
    """Reject negated regular expression matching in this compiler.

    :raises CompileError: unconditionally; the message names
     ``self.dialect.name``.
    """
    dialect_name = self.dialect.name
    raise exc.CompileError(
        "%s dialect does not support regular expressions" % dialect_name
    )
3750
def visit_regexp_replace_op_binary(
    self, binary: BinaryExpression[Any], operator: Any, **kw: Any
) -> str:
    """Reject regular expression replacement in this compiler.

    :raises CompileError: unconditionally; the message names
     ``self.dialect.name``.
    """
    dialect_name = self.dialect.name
    raise exc.CompileError(
        "%s dialect does not support regular expression replacements"
        % dialect_name
    )
3758
def visit_dmltargetcopy(self, element, *, bindmarkers=None, **kw):
    """Render a placeholder marker for a DML target column.

    ``element`` is recorded in ``bindmarkers`` under
    ``element.column.key`` and a ``__BINDMARKER_~~<key>~~`` token is
    returned.

    :raises CompileError: if no ``bindmarkers`` collection was passed.
    """
    if bindmarkers is not None:
        key = element.column.key
        bindmarkers[key] = element
        return f"__BINDMARKER_~~{key}~~"

    raise exc.CompileError(
        "DML target objects may only be used with "
        "compiled INSERT or UPDATE statements"
    )
3768
def visit_bindparam(
    self,
    bindparam,
    within_columns_clause=False,
    literal_binds=False,
    skip_bind_expression=False,
    literal_execute=False,
    render_postcompile=False,
    **kwargs,
):
    """Render a bound parameter.

    Depending on the flags, the result is an inline literal
    (``literal_binds``), a ``__[POSTCOMPILE_<name>]`` token (literal
    execute and/or expanding parameters) or the bind template string
    produced by ``bindparam_string()``.  The parameter is also
    registered in ``self.binds`` unless a literal was rendered.

    :param bindparam: the bound parameter being compiled.
    :param within_columns_clause: combined with ``self.ansi_bind_rules``
     to force literal execution.
    :param literal_binds: render the value inline instead of a bind.
    :param skip_bind_expression: do not apply the type's
     ``bind_expression()``; set on the recursive call made from here.
    :param literal_execute: render as a post-compile token.
    :param render_postcompile: when a post-compile token is emitted,
     sets ``self._render_postcompile``.
    :return: the rendered string; wrapped in parentheses for expanding
     parameters.
    :raises CompileError: if the name conflicts with a different
     parameter already registered under the same name.
    """

    if not skip_bind_expression:
        impl = bindparam.type.dialect_impl(self.dialect)
        if impl._has_bind_expression:
            bind_expression = impl.bind_expression(bindparam)
            # render the wrapping expression; the recursive visit of
            # the bindparam inside it skips this branch
            wrapped = self.process(
                bind_expression,
                skip_bind_expression=True,
                within_columns_clause=within_columns_clause,
                literal_binds=literal_binds and not bindparam.expanding,
                literal_execute=literal_execute,
                render_postcompile=render_postcompile,
                **kwargs,
            )
            if bindparam.expanding:
                # for postcompile w/ expanding, move the "wrapped" part
                # of this into the inside

                m = re.match(
                    r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
                )
                assert m, "unexpected format for expanding parameter"
                # token layout: name, text before, "REPL", text after
                wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
                    m.group(2),
                    m.group(1),
                    m.group(3),
                )

                if literal_binds:
                    ret = self.render_literal_bindparam(
                        bindparam,
                        within_columns_clause=True,
                        bind_expression_template=wrapped,
                        **kwargs,
                    )
                    return f"({ret})"

            return wrapped

    if not literal_binds:
        literal_execute = (
            literal_execute
            or bindparam.literal_execute
            or (within_columns_clause and self.ansi_bind_rules)
        )
        post_compile = literal_execute or bindparam.expanding
    else:
        post_compile = False

    if literal_binds:
        ret = self.render_literal_bindparam(
            bindparam, within_columns_clause=True, **kwargs
        )
        if bindparam.expanding:
            ret = f"({ret})"
        return ret

    name = self._truncate_bindparam(bindparam)

    # a different parameter object already registered under this name
    # is only accepted if none of the conflict checks below apply
    if name in self.binds:
        existing = self.binds[name]
        if existing is not bindparam:
            if (
                (existing.unique or bindparam.unique)
                and not existing.proxy_set.intersection(
                    bindparam.proxy_set
                )
                and not existing._cloned_set.intersection(
                    bindparam._cloned_set
                )
            ):
                raise exc.CompileError(
                    "Bind parameter '%s' conflicts with "
                    "unique bind parameter of the same name" % name
                )
            elif existing.expanding != bindparam.expanding:
                raise exc.CompileError(
                    "Can't reuse bound parameter name '%s' in both "
                    "'expanding' (e.g. within an IN expression) and "
                    "non-expanding contexts.  If this parameter is to "
                    "receive a list/array value, set 'expanding=True' on "
                    "it for expressions that aren't IN, otherwise use "
                    "a different parameter name." % (name,)
                )
            elif existing._is_crud or bindparam._is_crud:
                if existing._is_crud and bindparam._is_crud:
                    # TODO: this condition is not well understood.
                    # see tests in test/sql/test_update.py
                    raise exc.CompileError(
                        "Encountered unsupported case when compiling an "
                        "INSERT or UPDATE statement.  If this is a "
                        "multi-table "
                        "UPDATE statement, please provide string-named "
                        "arguments to the "
                        "values() method with distinct names; support for "
                        "multi-table UPDATE statements that "
                        "target multiple tables for UPDATE is very "
                        "limited",
                    )
                else:
                    raise exc.CompileError(
                        f"bindparam() name '{bindparam.key}' is reserved "
                        "for automatic usage in the VALUES or SET "
                        "clause of this "
                        "insert/update statement.   Please use a "
                        "name other than column name when using "
                        "bindparam() "
                        "with insert() or update() (for example, "
                        f"'b_{bindparam.key}')."
                    )

    # register under both the original key and the (possibly truncated)
    # name
    self.binds[bindparam.key] = self.binds[name] = bindparam

    # if we are given a cache key that we're going to match against,
    # relate the bindparam here to one that is most likely present
    # in the "extracted params" portion of the cache key.  this is used
    # to set up a positional mapping that is used to determine the
    # correct parameters for a subsequent use of this compiled with
    # a different set of parameter values.   here, we accommodate for
    # parameters that may have been cloned both before and after the cache
    # key has been generated.
    ckbm_tuple = self._cache_key_bind_match

    if ckbm_tuple:
        ckbm, cksm = ckbm_tuple
        for bp in bindparam._cloned_set:
            if bp.key in cksm:
                cb = cksm[bp.key]
                ckbm[cb].append(bindparam)

    if bindparam.isoutparam:
        self.has_out_parameters = True

    if post_compile:
        if render_postcompile:
            self._render_postcompile = True

        # track the parameter in the set matching how it is expanded
        if literal_execute:
            self.literal_execute_params |= {bindparam}
        else:
            self.post_compile_params |= {bindparam}

    ret = self.bindparam_string(
        name,
        post_compile=post_compile,
        expanding=bindparam.expanding,
        bindparam_type=bindparam.type,
        **kwargs,
    )

    if bindparam.expanding:
        ret = f"({ret})"

    return ret
3933
def render_bind_cast(self, type_, dbapi_type, sqltext):
    """Wrap ``sqltext`` in a cast; not implemented in this compiler.

    :raises NotImplementedError: unconditionally.
    """
    raise NotImplementedError()
3936
def render_literal_bindparam(
    self,
    bindparam,
    render_literal_value=NO_ARG,
    bind_expression_template=None,
    **kw,
):
    """Render the value of ``bindparam`` as an inline literal.

    :param render_literal_value: explicit value to render in place of
     the parameter's effective value.
    :param bind_expression_template: passed through to the expanding
     literal renderer when ``bindparam.expanding`` is set.
    :return: the rendered literal, or the comma-separated replacement
     expression for an expanding parameter.

    When no explicit value is given and the parameter has neither a
    value nor a callable, NULL is rendered; a warning is emitted if
    that happens inside a binary operator other than ``is`` /
    ``is not``.
    """
    if render_literal_value is NO_ARG:
        if bindparam.value is None and bindparam.callable is None:
            enclosing_op = kw.get("_binary_op", None)
            if enclosing_op and enclosing_op not in (
                operators.is_,
                operators.is_not,
            ):
                util.warn_limited(
                    "Bound parameter '%s' rendering literal NULL in a SQL "
                    "expression; comparisons to NULL should not use "
                    "operators outside of 'is' or 'is not'",
                    (bindparam.key,),
                )
            return self.process(sqltypes.NULLTYPE, **kw)
        value = bindparam.effective_value
    else:
        value = render_literal_value

    if not bindparam.expanding:
        return self.render_literal_value(value, bindparam.type)

    expand = self._literal_execute_expanding_parameter_literal_binds
    _, replacement_expr = expand(
        bindparam,
        value,
        bind_expression_template=bind_expression_template,
    )
    return replacement_expr
3969
def render_literal_value(
    self, value: Any, type_: sqltypes.TypeEngine[Any]
) -> str:
    """Render the value of a bind parameter as a quoted literal.

    This is used for statement sections that do not accept bind parameters
    on the target driver/database.

    This should be implemented by subclasses using the quoting services
    of the DBAPI.

    :raises CompileError: if the type has no literal processor for this
     dialect, or if that processor raises for ``value``.
    """

    if value is None and not type_.should_evaluate_none:
        # issue #10535 - handle NULL in the compiler without placing
        # this onto each type, except for "evaluate None" types
        # (e.g. JSON)
        return self.process(elements.Null._instance())

    literal_processor = type_._cached_literal_processor(self.dialect)
    if not literal_processor:
        raise exc.CompileError(
            f"No literal value renderer is available for literal value "
            f'"{sql_util._repr_single_value(value)}" '
            f"with datatype {type_}"
        )

    try:
        return literal_processor(value)
    except Exception as e:
        raise exc.CompileError(
            f"Could not render literal value "
            f'"{sql_util._repr_single_value(value)}" '
            f"with datatype "
            f"{type_}; see parent stack trace for "
            "more detail."
        ) from e
4008
def _truncate_bindparam(self, bindparam):
    """Return the name to use for ``bindparam``, caching the result.

    The name is ``bindparam.key``, passed through
    ``_truncated_identifier()`` when the key is a ``_truncated_label``.
    Results are remembered in ``self.bind_names``.
    """
    known_names = self.bind_names
    if bindparam in known_names:
        return known_names[bindparam]

    key = bindparam.key
    if isinstance(key, elements._truncated_label):
        bind_name = self._truncated_identifier("bindparam", key)
    else:
        bind_name = key

    # add to bind_names for translation
    known_names[bindparam] = bind_name

    return bind_name
4021
def _truncated_identifier(
    self, ident_class: str, name: _truncated_label
) -> str:
    """Return the rendered form of ``name``, shortened if necessary.

    ``name`` is first resolved against ``self.anon_map``.  If the result
    is longer than ``self.label_length - 6`` it is cut to that length
    (never below zero) and suffixed with ``_`` plus a hexadecimal
    counter kept per ``ident_class``.  Results are cached in
    ``self.truncated_names`` keyed on ``(ident_class, name)``.
    """
    cache_key = (ident_class, name)
    if cache_key in self.truncated_names:
        return self.truncated_names[cache_key]

    resolved = name.apply_map(self.anon_map)
    max_prefix = self.label_length - 6

    if len(resolved) > max_prefix:
        counter = self._truncated_counters.get(ident_class, 1)
        prefix = resolved[0 : max(max_prefix, 0)]
        result = prefix + "_" + hex(counter)[2:]
        self._truncated_counters[ident_class] = counter + 1
    else:
        result = resolved

    self.truncated_names[cache_key] = result
    return result
4042
def _anonymize(self, name: str) -> str:
    """Apply ``self.anon_map`` to ``name`` using ``%`` formatting."""
    mapping = self.anon_map
    return name % mapping
4045
def bindparam_string(
    self,
    name: str,
    post_compile: bool = False,
    expanding: bool = False,
    escaped_from: Optional[str] = None,
    bindparam_type: Optional[TypeEngine[Any]] = None,
    accumulate_bind_names: Optional[Set[str]] = None,
    visited_bindparam: Optional[List[str]] = None,
    **kw: Any,
) -> str:
    """Return the SQL text that stands for the bound parameter ``name``.

    :param name: the bind name; replaced by its escaped form when it
     matches ``self._bind_translate_re`` and no ``escaped_from`` was
     given.
    :param post_compile: render a ``__[POSTCOMPILE_<name>]`` token
     instead of a bind template.
    :param expanding: with ``post_compile``, return the bare token
     without applying any cast.
    :param escaped_from: the original name, when the caller has already
     escaped ``name``.
    :param bindparam_type: type consulted for literal / bind casts.
    :param accumulate_bind_names: if given, ``name`` is added to it.
    :param visited_bindparam: if given, ``name`` is appended to it.
    :return: the token or bind template string, possibly wrapped by
     ``render_bind_cast()``.
    """
    # TODO: accumulate_bind_names is passed by crud.py to gather
    # names on a per-value basis, visited_bindparam is passed by
    # visit_insert() to collect all parameters in the statement.
    # see if this gathering can be simplified somehow
    # note both collections receive the name as passed in, before any
    # escaping is applied below
    if accumulate_bind_names is not None:
        accumulate_bind_names.add(name)
    if visited_bindparam is not None:
        visited_bindparam.append(name)

    if not escaped_from:
        if self._bind_translate_re.search(name):
            # not quite the translate use case as we want to
            # also get a quick boolean if we even found
            # unusual characters in the name
            new_name = self._bind_translate_re.sub(
                lambda m: self._bind_translate_chars[m.group(0)],
                name,
            )
            escaped_from = name
            name = new_name

    if escaped_from:
        # record original -> escaped; union() is used rather than
        # updating the existing mapping in place
        self.escaped_bind_names = self.escaped_bind_names.union(
            {escaped_from: name}
        )
    if post_compile:
        ret = "__[POSTCOMPILE_%s]" % name
        if expanding:
            # for expanding, bound parameters or literal values will be
            # rendered per item
            return ret

        # otherwise, for non-expanding "literal execute", apply
        # bind casts as determined by the datatype
        if bindparam_type is not None:
            type_impl = bindparam_type._unwrapped_dialect_impl(
                self.dialect
            )
            if type_impl.render_literal_cast:
                ret = self.render_bind_cast(bindparam_type, type_impl, ret)
        return ret
    elif self.state is CompilerState.COMPILING:
        ret = self.compilation_bindtemplate % {"name": name}
    else:
        ret = self.bindtemplate % {"name": name}

    # regular binds are cast only when the dialect renders bind casts
    # and the type asks for one
    if (
        bindparam_type is not None
        and self.dialect._bind_typing_render_casts
    ):
        type_impl = bindparam_type._unwrapped_dialect_impl(self.dialect)
        if type_impl.render_bind_cast:
            ret = self.render_bind_cast(bindparam_type, type_impl, ret)

    return ret
4112
def _dispatch_independent_ctes(self, stmt, kw):
    """Compile each independent CTE attached to ``stmt``.

    Every CTE in ``stmt._independent_ctes`` is dispatched with its
    matching entry from ``stmt._independent_ctes_opts`` as ``cte_opts``.
    Any ``cte_opts`` already present in ``kw`` is not passed on, and
    ``kw`` itself is left unmodified.
    """
    passthrough_kw = {
        key: value for key, value in kw.items() if key != "cte_opts"
    }
    pairs = zip(stmt._independent_ctes, stmt._independent_ctes_opts)
    for cte, opt in pairs:
        cte._compiler_dispatch(self, cte_opts=opt, **passthrough_kw)
4120
def visit_cte(
    self,
    cte: CTE,
    asfrom: bool = False,
    ashint: bool = False,
    fromhints: Optional[_FromHintsType] = None,
    visiting_cte: Optional[CTE] = None,
    from_linter: Optional[FromLinter] = None,
    cte_opts: selectable._CTEOpts = selectable._CTEOpts(False),
    **kwargs: Any,
) -> Optional[str]:
    """Compile a CTE, recording its definition text and returning its name.

    The ``name [(columns)] AS (inner statement)`` text produced here is
    stored in the dictionary returned by ``_init_cte_state()`` (asserted
    to be ``self.ctes``), keyed by the CTE object.  This method does not
    itself emit a ``WITH`` clause.

    :param cte: the CTE being compiled.
    :param asfrom: when true, the return value is the string naming the
     CTE: its formatted name, or ``<pre-alias name> AS <name>`` when the
     CTE is an alias of another CTE.
    :param ashint: accepted, not referenced in this method.
    :param fromhints: accepted, not referenced in this method.
    :param visiting_cte: the CTE being compiled by an enclosing call, if
     any.  The kwargs passed to nested dispatches always carry ``cte``
     under this key.
    :param from_linter: when given and ``asfrom`` is true, the CTE is
     recorded in its ``froms`` mapping under its name.
    :param cte_opts: options for this occurrence of the CTE;
     ``cte_opts.nesting`` or ``cte.nesting`` selects the current stack
     depth, rather than level 1, as the level the CTE is registered at.
    :return: a string when ``asfrom`` is true, or when a new CTE is
     rendered while ``self.stack`` is empty (the compiled inner element
     is returned in that case); otherwise ``None``.
    :raises CompileError: if the CTE is marked as nesting in more than
     one location, or if an unrelated CTE with the same name is already
     registered at the same level.
    """
    self_ctes = self._init_cte_state()
    assert self_ctes is self.ctes

    # nested dispatches see this CTE as the one currently being visited
    kwargs["visiting_cte"] = cte

    cte_name = cte.name

    if isinstance(cte_name, elements._truncated_label):
        cte_name = self._truncated_identifier("alias", cte_name)

    is_new_cte = True
    embedded_in_current_named_cte = False

    _reference_cte = cte._get_reference_cte()

    nesting = cte.nesting or cte_opts.nesting

    # check for CTE already encountered
    if _reference_cte in self.level_name_by_cte:
        cte_level, _, existing_cte_opts = self.level_name_by_cte[
            _reference_cte
        ]
        assert _ == cte_name

        cte_level_name = (cte_level, cte_name)
        existing_cte = self.ctes_by_level_name[cte_level_name]

        # check if we are receiving it here with a specific
        # "nest_here" location; if so, move it to this location

        if cte_opts.nesting:
            if existing_cte_opts.nesting:
                raise exc.CompileError(
                    "CTE is stated as 'nest_here' in "
                    "more than one location"
                )

            # re-key both lookup structures under the new level
            old_level_name = (cte_level, cte_name)
            cte_level = len(self.stack) if nesting else 1
            cte_level_name = new_level_name = (cte_level, cte_name)

            del self.ctes_by_level_name[old_level_name]
            self.ctes_by_level_name[new_level_name] = existing_cte
            self.level_name_by_cte[_reference_cte] = new_level_name + (
                cte_opts,
            )

    else:
        # not seen by reference; a different CTE may still occupy the
        # same (level, name) slot
        cte_level = len(self.stack) if nesting else 1
        cte_level_name = (cte_level, cte_name)

        if cte_level_name in self.ctes_by_level_name:
            existing_cte = self.ctes_by_level_name[cte_level_name]
        else:
            existing_cte = None

    if existing_cte is not None:
        embedded_in_current_named_cte = visiting_cte is existing_cte

        # we've generated a same-named CTE that we are enclosed in,
        # or this is the same CTE.  just return the name.
        if cte is existing_cte._restates or cte is existing_cte:
            is_new_cte = False
        elif existing_cte is cte._restates:
            # we've generated a same-named CTE that is
            # enclosed in us - we take precedence, so
            # discard the text for the "inner".
            del self_ctes[existing_cte]

            existing_cte_reference_cte = existing_cte._get_reference_cte()

            assert existing_cte_reference_cte is _reference_cte
            assert existing_cte_reference_cte is existing_cte

            del self.level_name_by_cte[existing_cte_reference_cte]
        else:
            if (
                # if the two CTEs have the same hash, which we expect
                # here means that one/both is an annotated of the other
                (hash(cte) == hash(existing_cte))
                # or...
                or (
                    (
                        # if they are clones, i.e. they came from the ORM
                        # or some other visit method
                        cte._is_clone_of is not None
                        or existing_cte._is_clone_of is not None
                    )
                    # and are deep-copy identical
                    and cte.compare(existing_cte)
                )
            ):
                # then consider these two CTEs the same
                is_new_cte = False
            else:
                # otherwise these are two CTEs that either will render
                # differently, or were indicated separately by the user,
                # with the same name
                raise exc.CompileError(
                    "Multiple, unrelated CTEs found with "
                    "the same name: %r" % cte_name
                )

    # an already-registered CTE that is not being used as a FROM element
    # has nothing further to render
    if not asfrom and not is_new_cte:
        return None

    if cte._cte_alias is not None:
        pre_alias_cte = cte._cte_alias
        cte_pre_alias_name = cte._cte_alias.name
        if isinstance(cte_pre_alias_name, elements._truncated_label):
            cte_pre_alias_name = self._truncated_identifier(
                "alias", cte_pre_alias_name
            )
    else:
        pre_alias_cte = cte
        cte_pre_alias_name = None

    if is_new_cte:
        self.ctes_by_level_name[cte_level_name] = cte
        self.level_name_by_cte[_reference_cte] = cte_level_name + (
            cte_opts,
        )

        # make sure the CTE being aliased has been visited as well
        if pre_alias_cte not in self.ctes:
            self.visit_cte(pre_alias_cte, **kwargs)

        # definition text is generated only for a CTE that is not an
        # alias of another CTE and has no text recorded yet
        if not cte_pre_alias_name and cte not in self_ctes:
            if cte.recursive:
                self.ctes_recursive = True
            text = self.preparer.format_alias(cte, cte_name)
            if cte.recursive or cte.element.name_cte_columns:
                col_source = cte.element

                # TODO: can we get at the .columns_plus_names collection
                # that is already (or will be?) generated for the SELECT
                # rather than calling twice?
                recur_cols = [
                    # TODO: proxy_name is not technically safe,
                    # see test_cte->
                    # test_with_recursive_no_name_currently_buggy.  not
                    # clear what should be done with such a case
                    fallback_label_name or proxy_name
                    for (
                        _,
                        proxy_name,
                        fallback_label_name,
                        c,
                        repeated,
                    ) in (col_source._generate_columns_plus_names(True))
                    if not repeated
                ]

                text += "(%s)" % (
                    ", ".join(
                        self.preparer.format_label_name(
                            ident, anon_map=self.anon_map
                        )
                        for ident in recur_cols
                    )
                )

            assert kwargs.get("subquery", False) is False

            if not self.stack:
                # toplevel, this is a stringify of the
                # cte directly.  just compile the inner
                # the way alias() does.
                return cte.element._compiler_dispatch(
                    self, asfrom=asfrom, **kwargs
                )
            else:
                prefixes = self._generate_prefixes(
                    cte, cte._prefixes, **kwargs
                )
                inner = cte.element._compiler_dispatch(
                    self, asfrom=True, **kwargs
                )

                text += " AS %s\n(%s)" % (prefixes, inner)

            if cte._suffixes:
                text += " " + self._generate_prefixes(
                    cte, cte._suffixes, **kwargs
                )

            # record the definition text for this CTE
            self_ctes[cte] = text

    if asfrom:
        if from_linter:
            from_linter.froms[cte._de_clone()] = cte_name

        if not is_new_cte and embedded_in_current_named_cte:
            return self.preparer.format_alias(cte, cte_name)

        if cte_pre_alias_name:
            # alias of a CTE: "<pre-alias name> AS <alias name>"
            text = self.preparer.format_alias(cte, cte_pre_alias_name)
            if self.preparer._requires_quotes(cte_name):
                cte_name = self.preparer.quote(cte_name)
            text += self.get_render_as_alias_suffix(cte_name)
            return text  # type: ignore[no-any-return]
        else:
            return self.preparer.format_alias(cte, cte_name)

    return None
4336
def visit_table_valued_alias(self, element, **kw):
    """Render a table-valued alias via the lateral or plain alias visitor.

    When ``element.joins_implicitly`` is set, ``from_linter`` is forced
    to ``None`` in the keyword arguments passed along.
    """
    if element.joins_implicitly:
        kw["from_linter"] = None

    # choose the visitor first, then make a single call
    visitor = self.visit_lateral if element._is_lateral else self.visit_alias
    return visitor(element, **kw)
4344
def visit_table_valued_column(self, element, **kw):
    """Render a table-valued column exactly as ``visit_column`` does."""
    rendered = self.visit_column(element, **kw)
    return rendered
4347
def visit_alias(
    self,
    alias,
    asfrom=False,
    ashint=False,
    iscrud=False,
    fromhints=None,
    subquery=False,
    lateral=False,
    enclosing_alias=None,
    from_linter=None,
    **kwargs,
):
    """Render an alias construct.

    Depending on the flags, one of four things is returned:

    * when ``enclosing_alias.element`` is this alias, the compiled inner
      element, parenthesized if ``subquery`` and either ``asfrom`` or
      ``lateral`` are set;
    * when ``ashint`` is set, only the formatted alias name;
    * when ``asfrom`` is set, the compiled inner element followed by the
      alias suffix, optional derived-column list and optional hint text;
    * otherwise, the compiled inner element with no alias name.

    :param alias: the alias being rendered.
    :param asfrom: render as a FROM element, including the alias name.
    :param ashint: render only the alias name.
    :param iscrud: passed to ``format_from_hint_text()`` and to the inner
     element in the ``enclosing_alias`` case.
    :param fromhints: mapping consulted for hint text for this alias.
    :param subquery: wrap the inner element in parenthesis.
    :param lateral: the alias is being rendered as a LATERAL element.
    :param enclosing_alias: the alias whose compilation led to this call,
     if any.
    :param from_linter: when given and ``asfrom`` is set, the alias is
     recorded in its ``froms`` mapping under the alias name.
    """
    if lateral:
        if "enclosing_lateral" not in kwargs:
            # if lateral is set and enclosing_lateral is not
            # present, we assume we are being called directly
            # from visit_lateral() and we need to set enclosing_lateral.
            assert alias._is_lateral
            kwargs["enclosing_lateral"] = alias

        # for lateral objects, we track a second from_linter that is...
        # lateral!  to the level above us.
        if (
            from_linter
            and "lateral_from_linter" not in kwargs
            and "enclosing_lateral" in kwargs
        ):
            kwargs["lateral_from_linter"] = from_linter

    if enclosing_alias is not None and enclosing_alias.element is alias:
        # this alias is the element of the alias that invoked us: compile
        # the inner element only, with all rendering flags passed through
        inner = alias.element._compiler_dispatch(
            self,
            asfrom=asfrom,
            ashint=ashint,
            iscrud=iscrud,
            fromhints=fromhints,
            lateral=lateral,
            enclosing_alias=alias,
            **kwargs,
        )
        if subquery and (asfrom or lateral):
            inner = "(%s)" % (inner,)
        return inner
    else:
        kwargs["enclosing_alias"] = alias

    if asfrom or ashint:
        if isinstance(alias.name, elements._truncated_label):
            alias_name = self._truncated_identifier("alias", alias.name)
        else:
            alias_name = alias.name

    if ashint:
        return self.preparer.format_alias(alias, alias_name)
    elif asfrom:
        if from_linter:
            from_linter.froms[alias._de_clone()] = alias_name

        inner = alias.element._compiler_dispatch(
            self, asfrom=True, lateral=lateral, **kwargs
        )
        if subquery:
            inner = "(%s)" % (inner,)

        ret = inner + self.get_render_as_alias_suffix(
            self.preparer.format_alias(alias, alias_name)
        )

        # optional "(col1 [type], col2 [type], ...)" list after the name
        if alias._supports_derived_columns and alias._render_derived:
            ret += "(%s)" % (
                ", ".join(
                    "%s%s"
                    % (
                        self.preparer.quote(col.name),
                        (
                            " %s"
                            % self.dialect.type_compiler_instance.process(
                                col.type, **kwargs
                            )
                            if alias._render_derived_w_types
                            else ""
                        ),
                    )
                    for col in alias.c
                )
            )

        if fromhints and alias in fromhints:
            ret = self.format_from_hint_text(
                ret, alias, fromhints[alias], iscrud
            )

        return ret
    else:
        # note we cancel the "subquery" flag here as well
        return alias.element._compiler_dispatch(
            self, lateral=lateral, **kwargs
        )
4447
4448 def visit_subquery(self, subquery, **kw):
4449 kw["subquery"] = True
4450 return self.visit_alias(subquery, **kw)
4451
def visit_lateral(self, lateral_, **kw):
    """Render a lateral alias as ``LATERAL <alias rendering>``."""
    kw.update(lateral=True)
    alias_sql = self.visit_alias(lateral_, **kw)
    return "LATERAL %s" % alias_sql
4455
def visit_tablesample(self, tablesample, asfrom=False, **kw):
    """Render ``<alias> TABLESAMPLE <method>`` with an optional seed.

    The aliased element is always rendered with ``asfrom=True``; the
    ``asfrom`` argument given to this method is not passed along.  When
    ``tablesample.seed`` is not ``None``, `` REPEATABLE (<seed>)`` is
    appended.
    """
    sampled_sql = self.visit_alias(tablesample, asfrom=True, **kw)
    method_sql = tablesample._get_method()._compiler_dispatch(self, **kw)
    pieces = ["%s TABLESAMPLE %s" % (sampled_sql, method_sql)]

    seed = tablesample.seed
    if seed is not None:
        seed_sql = seed._compiler_dispatch(self, **kw)
        pieces.append(" REPEATABLE (%s)" % (seed_sql))

    return "".join(pieces)
4468
def _render_values(self, element, **kw):
    """Render the ``VALUES (...), (...)`` text for a VALUES construct.

    ``literal_binds`` defaults to ``element.literal_binds`` unless the
    caller supplied it.  Every row of every chunk in ``element._data``
    is rendered as a grouped tuple typed by ``element._column_types``.
    """
    kw.setdefault("literal_binds", element.literal_binds)

    rendered_rows = []
    for chunk in element._data:
        for row in chunk:
            grouped = elements.Tuple(
                *row, types=element._column_types
            ).self_group()
            rendered_rows.append(self.process(grouped, **kw))

    tuples = ", ".join(rendered_rows)
    return f"VALUES {tuples}"
4482
def visit_values(
    self, element, asfrom=False, from_linter=None, visiting_cte=None, **kw
):
    """Render a VALUES construct.

    Without ``asfrom`` the bare ``VALUES ...`` text is returned.  With
    ``asfrom`` the text is parenthesized, prefixed with ``LATERAL `` for
    a lateral element and, when the element has a name, followed by the
    alias suffix and a column list.  When the element is the element of
    ``visiting_cte`` the bare text is returned instead.

    :raises CompileError: for a lateral VALUES that is the element of
     the CTE being visited.
    """
    if element._independent_ctes:
        self._dispatch_independent_ctes(element, kw)

    values_sql = self._render_values(element, **kw)

    # the name is resolved up front, whether or not it is used below
    if element._unnamed:
        name = None
    else:
        name = element.name
        if isinstance(name, elements._truncated_label):
            name = self._truncated_identifier("values", name)

    lateral = "LATERAL " if element._is_lateral else ""

    if not asfrom:
        return values_sql

    if from_linter:
        from_linter.froms[element._de_clone()] = (
            "(unnamed VALUES element)" if name is None else name
        )

    if visiting_cte is not None and visiting_cte.element is element:
        if element._is_lateral:
            raise exc.CompileError(
                "Can't use a LATERAL VALUES expression inside of a CTE"
            )
        return values_sql

    if not name:
        return "%s(%s)" % (lateral, values_sql)

    # column names in the derived column list render without the table
    kw["include_table"] = False
    alias_suffix = self.get_render_as_alias_suffix(
        self.preparer.quote(name)
    )
    column_list = ", ".join(
        col._compiler_dispatch(self, **kw) for col in element.columns
    )
    return "%s(%s)%s (%s)" % (lateral, values_sql, alias_suffix, column_list)
4531
def visit_scalar_values(self, element, **kw):
    """Render a scalar VALUES construct wrapped in parenthesis."""
    values_sql = self._render_values(element, **kw)
    return f"({values_sql})"
4534
def get_render_as_alias_suffix(self, alias_name_text):
    """Return the text placed after an element to name it: `` AS <name>``."""
    as_keyword = " AS "
    return as_keyword + alias_name_text
4537
def _add_to_result_map(
    self,
    keyname: str,
    name: str,
    objects: Tuple[Any, ...],
    type_: TypeEngine[Any],
) -> None:
    """Append one entry to ``self._result_columns``.

    A ``keyname`` of ``None`` or ``"*"`` additionally turns off
    ``_ordered_columns`` and turns on ``_ad_hoc_textual``.

    :raises CompileError: if ``type_`` is a tuple type.
    """
    # note objects must be non-empty for cursor.py to handle the
    # collection properly
    assert objects

    if keyname is None or keyname == "*":
        self._ordered_columns, self._ad_hoc_textual = False, True

    if type_._is_tuple_type:
        raise exc.CompileError(
            "Most backends don't support SELECTing "
            "from a tuple() object.  If this is an ORM query, "
            "consider using the Bundle object."
        )

    entry = ResultColumnsEntry(keyname, name, objects, type_)
    self._result_columns.append(entry)
4562
def _label_returning_column(
    self, stmt, column, populate_result_map, column_clause_args=None, **kw
):
    """Render a column with necessary labels inside of a RETURNING clause.

    Dialects call this instead of ``_label_select_column`` so that the
    RETURNING and SELECT use cases can be told apart going forward.  The
    call is forwarded with no select, ``asfrom`` set to ``False`` and an
    empty dictionary when ``column_clause_args`` is not given; ``stmt``
    is not passed along.

    .. versionadded:: 1.4.21

    """
    if column_clause_args is None:
        clause_args = {}
    else:
        clause_args = column_clause_args

    return self._label_select_column(
        None, column, populate_result_map, False, clause_args, **kw
    )
4583
def _label_select_column(
    self,
    select,
    column,
    populate_result_map,
    asfrom,
    column_clause_args,
    name=None,
    proxy_name=None,
    fallback_label_name=None,
    within_columns_clause=True,
    column_is_repeated=False,
    need_column_expressions=False,
    include_table=True,
):
    """produce labeled columns present in a select().

    Decides whether ``column`` is compiled as-is or wrapped in a
    ``_CompileLabel``, then compiles it and returns the result.

    :param select: not referenced in this method; the RETURNING entry
     point passes ``None``.
    :param column: the column expression to render.
    :param populate_result_map: when true, a callable that records
     result column information is passed into the compilation.
    :param asfrom: whether the enclosing statement is rendered as a FROM
     element; takes part in the labeling heuristics.
    :param column_clause_args: dictionary of keyword arguments for the
     compilation.  NOTE: it is updated in place with
     ``within_columns_clause``, ``add_to_result_map`` and
     ``include_table``.
    :param name: explicit label name; requires ``proxy_name``.
    :param proxy_name: alternate name given to the generated label.
    :param fallback_label_name: label name used when the heuristics ask
     for a label and no explicit ``name`` was given; when empty,
     ``column._anon_name_label`` is used.
    :param within_columns_clause: must be true (asserted).
    :param column_is_repeated: the column repeats an earlier one; its
     result map entry then targets only the key name.
    :param need_column_expressions: apply the type's column expression
     even when the result map is not being populated.
    :param include_table: passed through to the compilation.
    """
    impl = column.type.dialect_impl(self.dialect)

    # a type-level column expression replaces the column to be rendered
    if impl._has_column_expression and (
        need_column_expressions or populate_result_map
    ):
        col_expr = impl.column_expression(column)
    else:
        col_expr = column

    if populate_result_map:
        # pass an "add_to_result_map" callable into the compilation
        # of embedded columns.  this collects information about the
        # column as it will be fetched in the result and is coordinated
        # with cursor.description when the query is executed.
        add_to_result_map = self._add_to_result_map

        # if the SELECT statement told us this column is a repeat,
        # wrap the callable with one that prevents the addition of the
        # targets
        if column_is_repeated:
            _add_to_result_map = add_to_result_map

            def add_to_result_map(keyname, name, objects, type_):
                _add_to_result_map(keyname, name, (keyname,), type_)

        # if we redefined col_expr for type expressions, wrap the
        # callable with one that adds the original column to the targets
        elif col_expr is not column:
            _add_to_result_map = add_to_result_map

            def add_to_result_map(keyname, name, objects, type_):
                _add_to_result_map(
                    keyname, name, (column,) + objects, type_
                )

    else:
        add_to_result_map = None

    # this method is used by some of the dialects for RETURNING,
    # which has different inputs.  _label_returning_column was added
    # as the better target for this now however for 1.4 we will keep
    # _label_select_column directly compatible with this use case.
    # these assertions right now set up the current expected inputs
    assert within_columns_clause, (
        "_label_select_column is only relevant within "
        "the columns clause of a SELECT or RETURNING"
    )
    if isinstance(column, elements.Label):
        # already a label: re-wrap only if the expression was replaced
        if col_expr is not column:
            result_expr = _CompileLabel(
                col_expr, column.name, alt_names=(column.element,)
            )
        else:
            result_expr = col_expr

    elif name:
        # here, _columns_plus_names has determined there's an explicit
        # label name we need to use.  this is the default for
        # tablenames_plus_columnnames as well as when columns are being
        # deduplicated on name

        assert (
            proxy_name is not None
        ), "proxy_name is required if 'name' is passed"

        result_expr = _CompileLabel(
            col_expr,
            name,
            alt_names=(
                proxy_name,
                # this is a hack to allow legacy result column lookups
                # to work as they did before; this goes away in 2.0.
                # TODO: this only seems to be tested indirectly
                # via test/orm/test_deprecations.py.   should be a
                # resultset test for this
                column._tq_label,
            ),
        )
    else:
        # determine here whether this column should be rendered in
        # a labelled context or not, as we were given no required label
        # name from the caller. Here we apply heuristics based on the kind
        # of SQL expression involved.

        if col_expr is not column:
            # type-specific expression wrapping the given column,
            # so we render a label
            render_with_label = True
        elif isinstance(column, elements.ColumnClause):
            # table-bound column, we render its name as a label if we are
            # inside of a subquery only
            render_with_label = (
                asfrom
                and not column.is_literal
                and column.table is not None
            )
        elif isinstance(column, elements.TextClause):
            render_with_label = False
        elif isinstance(column, elements.UnaryExpression):
            # unary expression.  notes added as of #12681
            #
            # By convention, the visit_unary() method
            # itself does not add an entry to the result map, and relies
            # upon either the inner expression creating a result map
            # entry, or if not, by creating a label here that produces
            # the result map entry.  Where that happens is based on whether
            # or not the element immediately inside the unary is a
            # NamedColumn subclass or not.
            #
            # Now, this also impacts how the SELECT is written; if
            # we decide to generate a label here, we get the usual
            # "~(x+y) AS anon_1" thing in the columns clause.   If we
            # don't, we don't get an AS at all, we get like
            # "~table.column".
            #
            # But here is the important thing as of modernish (like 1.4)
            # versions of SQLAlchemy - **whether or not the AS <label>
            # is present in the statement is not actually important**.
            # We target result columns **positionally** for a fully
            # compiled ``Select()`` object; before 1.4 we needed those
            # labels to match in cursor.description etc etc but now it
            # really doesn't matter.
            # So really, we could set render_with_label True in all cases.
            # Or we could just have visit_unary() populate the result map
            # in all cases.
            #
            # What we're doing here is strictly trying to not rock the
            # boat too much with when we do/don't render "AS label";
            # labels being present helps in the edge cases that we
            # "fall back" to named cursor.description matching, labels
            # not being present for columns keeps us from having awkward
            # phrases like "SELECT DISTINCT table.x AS x".
            render_with_label = (
                (
                    # exception case to detect if we render "not boolean"
                    # as "not <col>" for native boolean or "<col> = 1"
                    # for non-native boolean.   this is controlled by
                    # visit_is_<true|false>_unary_operator
                    column.operator
                    in (operators.is_false, operators.is_true)
                    and not self.dialect.supports_native_boolean
                )
                or column._wraps_unnamed_column()
                or asfrom
            )
        elif (
            # general class of expressions that don't have a SQL-column
            # addressible name.  includes scalar selects, bind parameters,
            # SQL functions, others
            not isinstance(column, elements.NamedColumn)
            # deeper check that indicates there's no natural "name" to
            # this element, which accommodates for custom SQL constructs
            # that might have a ".name" attribute (but aren't SQL
            # functions) but are not implementing this more recently added
            # base class.  in theory the "NamedColumn" check should be
            # enough, however here we seek to maintain legacy behaviors
            # as well.
            and column._non_anon_label is None
        ):
            render_with_label = True
        else:
            render_with_label = False

        if render_with_label:
            if not fallback_label_name:
                # used by the RETURNING case right now.  we generate it
                # here as 3rd party dialects may be referring to
                # _label_select_column method directly instead of the
                # just-added _label_returning_column method
                assert not column_is_repeated
                fallback_label_name = column._anon_name_label

            # make sure the label name is a _truncated_label instance
            fallback_label_name = (
                elements._truncated_label(fallback_label_name)
                if not isinstance(
                    fallback_label_name, elements._truncated_label
                )
                else fallback_label_name
            )

            result_expr = _CompileLabel(
                col_expr, fallback_label_name, alt_names=(proxy_name,)
            )
        else:
            result_expr = col_expr

    # NOTE: this updates the caller's dictionary in place
    column_clause_args.update(
        within_columns_clause=within_columns_clause,
        add_to_result_map=add_to_result_map,
        include_table=include_table,
    )
    return result_expr._compiler_dispatch(self, **column_clause_args)
4792
def format_from_hint_text(self, sqltext, table, hint, iscrud):
    """Append the hint text for ``table`` to ``sqltext``.

    The text comes from ``get_from_hint_text()``; when that is empty or
    ``None``, ``sqltext`` is returned unchanged.  ``iscrud`` is accepted
    but not used here.
    """
    hinttext = self.get_from_hint_text(table, hint)
    if not hinttext:
        return sqltext
    return sqltext + (" " + hinttext)
4798
def get_select_hint_text(self, byfroms):
    """Return statement-level hint text for a SELECT.

    This implementation produces no hint text and always returns
    ``None``, whatever ``byfroms`` contains.
    """
    return None
4801
def get_from_hint_text(
    self, table: FromClause, text: Optional[str]
) -> Optional[str]:
    """Return hint text to render after a FROM element.

    This implementation produces no hint text and always returns
    ``None``, whatever ``table`` and ``text`` are.
    """
    return None
4806
def get_crud_hint_text(self, table, text):
    """Return hint text for a table in an INSERT/UPDATE/DELETE.

    This implementation produces no hint text and always returns
    ``None``, whatever ``table`` and ``text`` are.
    """
    return None
4809
def get_statement_hint_text(self, hint_texts):
    """Combine statement hint strings, separated by single spaces."""
    separator = " "
    return separator.join(hint_texts)
4812
# Stack entry used in place of ``self.stack[-1]`` when the compiler stack
# is empty; both of its "froms" collections are empty frozensets.
_default_stack_entry: _CompilerStackEntry

# The value is assigned only at runtime; under type checking just the
# annotation above is seen.
if not typing.TYPE_CHECKING:
    _default_stack_entry = util.immutabledict(
        [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
    )
4819
def _display_froms_for_select(
    self, select_stmt, asfrom, lateral=False, **kw
):
    """Return the FROM list that would be displayed for ``select_stmt``.

    Utility method to help external dialects get the correct from list
    for a select; specifically the oracle dialect needs this feature
    right now.  Extra keyword arguments are accepted and ignored.
    """
    if self.stack:
        entry = self.stack[-1]
    else:
        entry = self._default_stack_entry

    compile_state = select_stmt._compile_state_factory(select_stmt, self)

    correlate_froms = entry["correlate_froms"]
    asfrom_froms = entry["asfrom_froms"]

    if asfrom and not lateral:
        # non-lateral FROM element: the enclosing statement's own FROM
        # objects are excluded from correlation
        explicit = correlate_froms.difference(asfrom_froms)
        implicit = ()
    else:
        explicit, implicit = correlate_froms, asfrom_froms

    return compile_state._get_display_froms(
        explicit_correlate_froms=explicit,
        implicit_correlate_froms=implicit,
    )
4848
# Optional hook; ``None`` here means no structural translation of SELECT
# statements is configured.
translate_select_structure: Any = None
"""if not ``None``, should be a callable which accepts ``(select_stmt,
**kw)`` and returns a select object.   this is used for structural changes
mostly to accommodate for LIMIT/OFFSET schemes

"""
4855
def visit_select(
    self,
    select_stmt,
    asfrom=False,
    insert_into=False,
    fromhints=None,
    compound_index=None,
    select_wraps_for=None,
    lateral=False,
    from_linter=None,
    **kwargs,
):
    """Compile a SELECT statement and return its SQL text.

    A compile state is built for the statement, an entry is pushed on
    ``self.stack`` for the duration of the compilation (and popped before
    returning), the columns clause and the rest of the body are rendered,
    and any collected CTE text is prepended.

    :param select_stmt: the statement to compile; replaced by
     ``compile_state.statement`` once the compile state exists.
    :param asfrom: the SELECT is being rendered as a FROM element.
    :param insert_into: the SELECT is embedded in another statement;
     together with ``compound_index`` this defers CTE rendering unless
     this is the toplevel statement.
    :param fromhints: accepted, not referenced in this method.
    :param compound_index: position of this SELECT within a compound
     select, or ``None``; the result map is populated only when it is
     ``None`` or ``0``.
    :param select_wraps_for: must be ``None`` (asserted); it is assigned
     internally when ``translate_select_structure`` returns a new
     statement.
    :param lateral: the SELECT is being rendered as a LATERAL element.
    :param from_linter: accepted, not referenced in this method.
    """
    assert select_wraps_for is None, (
        "SQLAlchemy 1.4 requires use of "
        "the translate_select_structure hook for structural "
        "translations of SELECT objects"
    )

    # initial setup of SELECT.  the compile_state_factory may now
    # be creating a totally different SELECT from the one that was
    # passed in.  for ORM use this will convert from an ORM-state
    # SELECT to a regular "Core" SELECT.  other composed operations
    # such as computation of joins will be performed.

    kwargs["within_columns_clause"] = False

    compile_state = select_stmt._compile_state_factory(
        select_stmt, self, **kwargs
    )
    kwargs["ambiguous_table_name_map"] = (
        compile_state._ambiguous_table_name_map
    )

    select_stmt = compile_state.statement

    toplevel = not self.stack

    if toplevel and not self.compile_state:
        self.compile_state = compile_state

    is_embedded_select = compound_index is not None or insert_into

    # translate step for Oracle, SQL Server which often need to
    # restructure the SELECT to allow for LIMIT/OFFSET and possibly
    # other conditions
    if self.translate_select_structure:
        new_select_stmt = self.translate_select_structure(
            select_stmt, asfrom=asfrom, **kwargs
        )

        # if SELECT was restructured, maintain a link to the originals
        # and assemble a new compile state
        if new_select_stmt is not select_stmt:
            compile_state_wraps_for = compile_state
            select_wraps_for = select_stmt
            select_stmt = new_select_stmt

            compile_state = select_stmt._compile_state_factory(
                select_stmt, self, **kwargs
            )
            select_stmt = compile_state.statement

    entry = self._default_stack_entry if toplevel else self.stack[-1]

    populate_result_map = need_column_expressions = (
        toplevel
        or entry.get("need_result_map_for_compound", False)
        or entry.get("need_result_map_for_nested", False)
    )

    # indicates there is a CompoundSelect in play and we are not the
    # first select
    if compound_index:
        populate_result_map = False

    # this was first proposed as part of #3372; however, it is not
    # reached in current tests and could possibly be an assertion
    # instead.
    if not populate_result_map and "add_to_result_map" in kwargs:
        del kwargs["add_to_result_map"]

    # pushes a new entry onto self.stack; popped at the end of this method
    froms = self._setup_select_stack(
        select_stmt, compile_state, entry, asfrom, lateral, compound_index
    )

    # separate copy of the kwargs used for the columns clause only
    column_clause_args = kwargs.copy()
    column_clause_args.update(
        {"within_label_clause": False, "within_columns_clause": False}
    )

    text = "SELECT "  # we're off to a good start !

    if select_stmt._post_select_clause is not None:
        psc = self.process(select_stmt._post_select_clause, **kwargs)
        if psc is not None:
            text += psc + " "

    if select_stmt._hints:
        hint_text, byfrom = self._setup_select_hints(select_stmt)
        if hint_text:
            text += hint_text + " "
    else:
        byfrom = None

    if select_stmt._independent_ctes:
        self._dispatch_independent_ctes(select_stmt, kwargs)

    if select_stmt._prefixes:
        text += self._generate_prefixes(
            select_stmt, select_stmt._prefixes, **kwargs
        )

    text += self.get_select_precolumns(select_stmt, **kwargs)

    if select_stmt._pre_columns_clause is not None:
        pcc = self.process(select_stmt._pre_columns_clause, **kwargs)
        if pcc is not None:
            text += pcc + " "

    # the actual list of columns to print in the SELECT column list.
    inner_columns = [
        c
        for c in [
            self._label_select_column(
                select_stmt,
                column,
                populate_result_map,
                asfrom,
                column_clause_args,
                name=name,
                proxy_name=proxy_name,
                fallback_label_name=fallback_label_name,
                column_is_repeated=repeated,
                need_column_expressions=need_column_expressions,
            )
            for (
                name,
                proxy_name,
                fallback_label_name,
                column,
                repeated,
            ) in compile_state.columns_plus_names
        ]
        if c is not None
    ]

    if populate_result_map and select_wraps_for is not None:
        # if this select was generated from translate_select,
        # rewrite the targeted columns in the result map

        # maps each column of the translated statement to the column at
        # the same position in the original statement; note that "name"
        # below is bound to the fourth element of each entry, which is
        # the column object
        translate = dict(
            zip(
                [
                    name
                    for (
                        key,
                        proxy_name,
                        fallback_label_name,
                        name,
                        repeated,
                    ) in compile_state.columns_plus_names
                ],
                [
                    name
                    for (
                        key,
                        proxy_name,
                        fallback_label_name,
                        name,
                        repeated,
                    ) in compile_state_wraps_for.columns_plus_names
                ],
            )
        )

        self._result_columns = [
            ResultColumnsEntry(
                key, name, tuple(translate.get(o, o) for o in obj), type_
            )
            for key, name, obj, type_ in self._result_columns
        ]

    text = self._compose_select_body(
        text,
        select_stmt,
        compile_state,
        inner_columns,
        froms,
        byfrom,
        toplevel,
        kwargs,
    )

    if select_stmt._post_body_clause is not None:
        pbc = self.process(select_stmt._post_body_clause, **kwargs)
        if pbc:
            text += " " + pbc

    if select_stmt._statement_hints:
        per_dialect = [
            ht
            for (dialect_name, ht) in select_stmt._statement_hints
            if dialect_name in ("*", self.dialect.name)
        ]
        if per_dialect:
            text += " " + self.get_statement_hint_text(per_dialect)

    # In compound query, CTEs are shared at the compound level
    if self.ctes and (not is_embedded_select or toplevel):
        nesting_level = len(self.stack) if not toplevel else None
        text = self._render_cte_clause(nesting_level=nesting_level) + text

    if select_stmt._suffixes:
        text += " " + self._generate_prefixes(
            select_stmt, select_stmt._suffixes, **kwargs
        )

    # remove the stack entry pushed by _setup_select_stack()
    self.stack.pop(-1)

    return text
5076
def _setup_select_hints(
    self, select: Select[Unpack[TupleAny]]
) -> Tuple[str, _FromHintsType]:
    """Resolve the per-FROM hints of ``select`` for the current dialect.

    Hints registered for ``"*"`` or for ``self.dialect.name`` are kept;
    each has the ``name`` placeholder filled in with the FROM element
    rendered in hint mode.

    :return: a tuple of the statement-level hint text from
     ``get_select_hint_text()`` and the mapping of FROM element to
     resolved hint text.
    """
    byfrom = {}
    for (from_, dialect), hinttext in select._hints.items():
        if dialect not in ("*", self.dialect.name):
            continue
        rendered_name = from_._compiler_dispatch(self, ashint=True)
        byfrom[from_] = hinttext % {"name": rendered_name}

    hint_text = self.get_select_hint_text(byfrom)
    return hint_text, byfrom
5088
def _setup_select_stack(
    self, select, compile_state, entry, asfrom, lateral, compound_index
):
    """Push a stack entry for ``select`` and return its display FROMs.

    For the first member of a compound select (``compound_index == 0``)
    the select is remembered in ``entry``; later members are checked to
    have the same number of columns as that first select.

    :raises CompileError: if a later member of a compound select has a
     different number of columns than the first.
    :return: the FROM objects to display, from
     ``compile_state._get_display_froms()``.
    """
    correlate_froms = entry["correlate_froms"]
    asfrom_froms = entry["asfrom_froms"]

    if compound_index == 0:
        entry["select_0"] = select
    elif compound_index:
        first_select = entry["select_0"]
        expected_cols = len(first_select._all_selected_columns)

        if len(compile_state.columns_plus_names) != expected_cols:
            counts = (
                1,
                expected_cols,
                compound_index + 1,
                len(select._all_selected_columns),
            )
            raise exc.CompileError(
                "All selectables passed to "
                "CompoundSelect must have identical numbers of "
                "columns; select #%d has %d columns, select "
                "#%d has %d" % counts
            )

    if asfrom and not lateral:
        explicit = correlate_froms.difference(asfrom_froms)
        implicit = ()
    else:
        explicit, implicit = correlate_froms, asfrom_froms

    froms = compile_state._get_display_froms(
        explicit_correlate_froms=explicit,
        implicit_correlate_froms=implicit,
    )

    # the FROM objects of this select, plus everything already
    # available for correlation from the enclosing entry
    own_froms = set(_from_objects(*froms))
    self.stack.append(
        {
            "asfrom_froms": own_froms,
            "correlate_froms": own_froms.union(correlate_froms),
            "selectable": select,
            "compile_state": compile_state,
        }
    )

    return froms
5140
def _compose_select_body(
    self,
    text,
    select,
    compile_state,
    inner_columns,
    froms,
    byfrom,
    toplevel,
    kwargs,
):
    """Append the columns and all following clauses to ``text``.

    In order: the column list, FROM (or ``default_from()`` when there
    are no FROM objects), WHERE, GROUP BY, HAVING, the post-criteria
    clause, ORDER BY, the row limiting clause and FOR UPDATE.  When
    cartesian-product linting is enabled a ``FromLinter`` is created and
    passed to the FROM and WHERE compilation; for the toplevel statement
    it is also stored as ``self.from_linter``.

    :return: the extended statement text.
    """
    text += ", ".join(inner_columns)

    from_linter = None
    warn_linting = False
    if self.linting & COLLECT_CARTESIAN_PRODUCTS:
        from_linter = FromLinter({}, set())
        warn_linting = self.linting & WARN_LINTING
        if toplevel:
            self.from_linter = from_linter

    # adjust the whitespace for no inner columns, part of #9440,
    # so that a no-col SELECT comes out as "SELECT WHERE..." or
    # "SELECT FROM ...".
    # while it would be better to have built the SELECT starting string
    # without trailing whitespace first, then add whitespace only if inner
    # cols were present, this breaks compatibility with various custom
    # compilation schemes that are currently being tested.
    if not inner_columns:
        text = text.rstrip()

    if not froms:
        text += self.default_from()
    else:
        # "fromhints" is passed along only when the select has hints
        hint_kw = {"fromhints": byfrom} if select._hints else {}
        rendered_froms = [
            from_obj._compiler_dispatch(
                self,
                asfrom=True,
                from_linter=from_linter,
                **hint_kw,
                **kwargs,
            )
            for from_obj in froms
        ]
        text += " \nFROM " + ", ".join(rendered_froms)

    if select._where_criteria:
        where_sql = self._generate_delimited_and_list(
            select._where_criteria, from_linter=from_linter, **kwargs
        )
        if where_sql:
            text += " \nWHERE " + where_sql

    if warn_linting:
        assert from_linter is not None
        from_linter.warn()

    if select._group_by_clauses:
        text += self.group_by_clause(select, **kwargs)

    if select._having_criteria:
        having_sql = self._generate_delimited_and_list(
            select._having_criteria, **kwargs
        )
        if having_sql:
            text += " \nHAVING " + having_sql

    if select._post_criteria_clause is not None:
        post_criteria_sql = self.process(
            select._post_criteria_clause, **kwargs
        )
        if post_criteria_sql is not None:
            text += " \n" + post_criteria_sql

    if select._order_by_clauses:
        text += self.order_by_clause(select, **kwargs)

    if select._has_row_limiting_clause:
        text += self._row_limit_clause(select, **kwargs)

    if select._for_update_arg is not None:
        text += self.for_update_clause(select, **kwargs)

    return text
5240
5241 def _generate_prefixes(self, stmt, prefixes, **kw):
5242 clause = " ".join(
5243 prefix._compiler_dispatch(self, **kw)
5244 for prefix, dialect_name in prefixes
5245 if dialect_name in (None, "*") or dialect_name == self.dialect.name
5246 )
5247 if clause:
5248 clause += " "
5249 return clause
5250
5251 def _render_cte_clause(
5252 self,
5253 nesting_level=None,
5254 include_following_stack=False,
5255 ):
5256 """
5257 include_following_stack
5258 Also render the nesting CTEs on the next stack. Useful for
5259 SQL structures like UNION or INSERT that can wrap SELECT
5260 statements containing nesting CTEs.
5261 """
5262 if not self.ctes:
5263 return ""
5264
5265 ctes: MutableMapping[CTE, str]
5266
5267 if nesting_level and nesting_level > 1:
5268 ctes = util.OrderedDict()
5269 for cte in list(self.ctes.keys()):
5270 cte_level, cte_name, cte_opts = self.level_name_by_cte[
5271 cte._get_reference_cte()
5272 ]
5273 nesting = cte.nesting or cte_opts.nesting
5274 is_rendered_level = cte_level == nesting_level or (
5275 include_following_stack and cte_level == nesting_level + 1
5276 )
5277 if not (nesting and is_rendered_level):
5278 continue
5279
5280 ctes[cte] = self.ctes[cte]
5281
5282 else:
5283 ctes = self.ctes
5284
5285 if not ctes:
5286 return ""
5287 ctes_recursive = any([cte.recursive for cte in ctes])
5288
5289 cte_text = self.get_cte_preamble(ctes_recursive) + " "
5290 cte_text += ", \n".join([txt for txt in ctes.values()])
5291 cte_text += "\n "
5292
5293 if nesting_level and nesting_level > 1:
5294 for cte in list(ctes.keys()):
5295 cte_level, cte_name, cte_opts = self.level_name_by_cte[
5296 cte._get_reference_cte()
5297 ]
5298 del self.ctes[cte]
5299 del self.ctes_by_level_name[(cte_level, cte_name)]
5300 del self.level_name_by_cte[cte._get_reference_cte()]
5301
5302 return cte_text
5303
def get_cte_preamble(self, recursive):
    """Return the keyword(s) that open a CTE clause.

    ``"WITH RECURSIVE"`` when ``recursive`` is true, else ``"WITH"``.

    """
    return "WITH RECURSIVE" if recursive else "WITH"
5309
def get_select_precolumns(self, select: Select[Any], **kw: Any) -> str:
    """Return the text rendered just before the column list of a
    ``SELECT`` statement.

    This implementation renders ``"DISTINCT "`` when the select is
    distinct and an empty string otherwise.  DISTINCT ON expressions
    are not rendered here; their presence only emits a deprecation
    warning.

    """
    if select._distinct_on:
        util.warn_deprecated(
            "DISTINCT ON is currently supported only by the PostgreSQL "
            "dialect. Use of DISTINCT ON for other backends is currently "
            "silently ignored, however this usage is deprecated, and will "
            "raise CompileError in a future release for all backends "
            "that do not support this syntax.",
            version="1.4",
        )

    if select._distinct:
        return "DISTINCT "
    return ""
5325
def group_by_clause(self, select, **kw):
    """Render the GROUP BY clause of a SELECT; dialects may override
    this to customize the rendering.

    Returns an empty string when the grouping expressions render to
    nothing.

    """
    rendered = self._generate_delimited_list(
        select._group_by_clauses, OPERATORS[operators.comma_op], **kw
    )
    return " GROUP BY " + rendered if rendered else ""
5336
def order_by_clause(self, select, **kw):
    """Render the ORDER BY clause of a SELECT; dialects may override
    this to customize the rendering.

    Returns an empty string when the ordering expressions render to
    nothing.

    """
    rendered = self._generate_delimited_list(
        select._order_by_clauses, OPERATORS[operators.comma_op], **kw
    )
    return " ORDER BY " + rendered if rendered else ""
5348
def for_update_clause(self, select, **kw):
    """Return the FOR UPDATE suffix for a SELECT.

    This implementation ignores ``select`` and any keyword arguments
    and always returns ``" FOR UPDATE"``.

    """
    return " FOR UPDATE"
5351
def returning_clause(
    self,
    stmt: UpdateBase,
    returning_cols: Sequence[_ColumnsClauseElement],
    *,
    populate_result_map: bool,
    **kw: Any,
) -> str:
    """Render a ``RETURNING`` clause for a DML statement.

    Each entry produced by ``stmt._generate_columns_plus_names()`` for
    the given columns is labeled via ``_label_returning_column()``; the
    results are joined with ``", "`` after the ``RETURNING`` keyword.

    :param stmt: the DML statement being compiled.
    :param returning_cols: the column elements to return.
    :param populate_result_map: passed through to
     ``_label_returning_column()``.

    """
    labeled = []
    for (
        name,
        proxy_name,
        fallback_label_name,
        column,
        repeated,
    ) in stmt._generate_columns_plus_names(
        True, cols=base._select_iterables(returning_cols)
    ):
        labeled.append(
            self._label_returning_column(
                stmt,
                column,
                populate_result_map,
                fallback_label_name=fallback_label_name,
                column_is_repeated=repeated,
                name=name,
                proxy_name=proxy_name,
                **kw,
            )
        )

    return "RETURNING " + ", ".join(labeled)
5383
def limit_clause(self, select, **kw):
    """Render the LIMIT / OFFSET portion of a SELECT.

    When only an offset is present, ``LIMIT -1`` is rendered ahead of
    the ``OFFSET``.  Returns an empty string when the select has
    neither a limit nor an offset.

    """
    limit = select._limit_clause
    offset = select._offset_clause

    pieces = []
    if limit is not None:
        pieces.append("\n LIMIT " + self.process(limit, **kw))
    if offset is not None:
        if limit is None:
            pieces.append("\n LIMIT -1")
        pieces.append(" OFFSET " + self.process(offset, **kw))
    return "".join(pieces)
5393
def fetch_clause(
    self,
    select,
    fetch_clause=None,
    require_offset=False,
    use_literal_execute_for_simple_int=False,
    **kw,
):
    """Render ``OFFSET n ROWS`` and/or ``FETCH FIRST n ROWS ...``.

    :param select: the SELECT supplying the offset clause and, when
     ``fetch_clause`` is not given, the fetch clause and its options.
    :param fetch_clause: explicit clause to render in place of the
     select's own; when given, neither ``PERCENT`` nor ``WITH TIES``
     is rendered.
    :param require_offset: render ``OFFSET 0 ROWS`` when the select has
     no offset clause.
    :param use_literal_execute_for_simple_int: when true, clauses that
     ``select._simple_int_clause()`` accepts are replaced by the result
     of their ``render_literal_execute()`` before being processed.
    :return: the rendered text, possibly empty.

    """
    if fetch_clause is not None:
        fetch_clause_options = {"percent": False, "with_ties": False}
    else:
        fetch_clause = select._fetch_clause
        fetch_clause_options = select._fetch_clause_options

    def _literal_if_simple(clause):
        if use_literal_execute_for_simple_int and select._simple_int_clause(
            clause
        ):
            return clause.render_literal_execute()
        return clause

    pieces = []

    offset_clause = select._offset_clause
    if offset_clause is not None:
        offset_str = self.process(_literal_if_simple(offset_clause), **kw)
        pieces.append("\n OFFSET %s ROWS" % offset_str)
    elif require_offset:
        pieces.append("\n OFFSET 0 ROWS")

    if fetch_clause is not None:
        fetch_str = self.process(_literal_if_simple(fetch_clause), **kw)
        percent = " PERCENT" if fetch_clause_options["percent"] else ""
        ties = "WITH TIES" if fetch_clause_options["with_ties"] else "ONLY"
        pieces.append(
            "\n FETCH FIRST %s%s ROWS %s" % (fetch_str, percent, ties)
        )

    return "".join(pieces)
5434
def visit_table(
    self,
    table,
    asfrom=False,
    iscrud=False,
    ashint=False,
    fromhints=None,
    use_schema=True,
    from_linter=None,
    ambiguous_table_name_map=None,
    enclosing_alias=None,
    **kwargs,
):
    """Render a table reference.

    The table is always recorded with ``from_linter`` when one is
    given.  Text is only produced when ``asfrom`` or ``ashint`` is
    true; otherwise an empty string is returned.

    The rendered name is schema-qualified when ``use_schema`` is true
    and the preparer reports a schema for the table.  A schema-less
    table whose name appears in ``ambiguous_table_name_map`` is given
    an anonymous alias, unless it is the element of
    ``enclosing_alias``.  When ``fromhints`` has an entry for the
    table, the text is passed through ``format_from_hint_text()``.

    """
    if from_linter:
        from_linter.froms[table] = table.fullname

    if not (asfrom or ashint):
        return ""

    preparer = self.preparer
    effective_schema = preparer.schema_for_object(table)

    if use_schema and effective_schema:
        rendered = (
            preparer.quote_schema(effective_schema)
            + "."
            + preparer.quote(table.name)
        )
    else:
        rendered = preparer.quote(table.name)

    wrapped_by_own_alias = (
        enclosing_alias is not None and enclosing_alias.element is table
    )
    if (
        not wrapped_by_own_alias
        and not effective_schema
        and ambiguous_table_name_map
        and table.name in ambiguous_table_name_map
    ):
        anon_name = self._truncated_identifier(
            "alias", ambiguous_table_name_map[table.name]
        )
        rendered += self.get_render_as_alias_suffix(
            preparer.format_alias(None, anon_name)
        )

    if fromhints and table in fromhints:
        rendered = self.format_from_hint_text(
            rendered, table, fromhints[table], iscrud
        )
    return rendered
5487
def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
    """Render a JOIN as ``<left> <join keyword> <right> ON <onclause>``.

    The keyword is ``FULL OUTER JOIN`` when ``join.full`` is set,
    ``LEFT OUTER JOIN`` when ``join.isouter`` is set, else ``JOIN``.
    When a ``from_linter`` is given, every pairing of a FROM object on
    the left side with one on the right side is recorded as an edge.

    """
    if from_linter:
        left_froms = _de_clone(join.left._from_objects)
        right_froms = _de_clone(join.right._from_objects)
        from_linter.edges.update(itertools.product(left_froms, right_froms))

    if join.full:
        join_type = " FULL OUTER JOIN "
    else:
        join_type = " LEFT OUTER JOIN " if join.isouter else " JOIN "

    left_text = join.left._compiler_dispatch(
        self, asfrom=True, from_linter=from_linter, **kwargs
    )
    right_text = join.right._compiler_dispatch(
        self, asfrom=True, from_linter=from_linter, **kwargs
    )
    # TODO: likely need asfrom=True here?
    on_text = join.onclause._compiler_dispatch(
        self, from_linter=from_linter, **kwargs
    )
    return left_text + join_type + right_text + " ON " + on_text
5517
5518 def _setup_crud_hints(self, stmt, table_text):
5519 dialect_hints = {
5520 table: hint_text
5521 for (table, dialect), hint_text in stmt._hints.items()
5522 if dialect in ("*", self.dialect.name)
5523 }
5524 if stmt.table in dialect_hints:
5525 table_text = self.format_from_hint_text(
5526 table_text, stmt.table, dialect_hints[stmt.table], True
5527 )
5528 return dialect_hints, table_text
5529
# within the realm of "insertmanyvalues sentinel columns",
# these lookups match different kinds of Column() configurations
# to specific backend capabilities.  they are broken into two
# lookups, one for autoincrement columns and the other for non
# autoincrement columns.  the autoincrement lookup differs from the
# non-autoincrement one only in its entry for the NONE characterization.
_sentinel_col_non_autoinc_lookup = util.immutabledict(
    {
        **{
            characterization: InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT
            for characterization in (
                _SentinelDefaultCharacterization.CLIENTSIDE,
                _SentinelDefaultCharacterization.SENTINEL_DEFAULT,
                _SentinelDefaultCharacterization.NONE,
            )
        },
        _SentinelDefaultCharacterization.IDENTITY: (
            InsertmanyvaluesSentinelOpts.IDENTITY
        ),
        _SentinelDefaultCharacterization.SEQUENCE: (
            InsertmanyvaluesSentinelOpts.SEQUENCE
        ),
    }
)
_sentinel_col_autoinc_lookup = _sentinel_col_non_autoinc_lookup.union(
    {
        _SentinelDefaultCharacterization.NONE: (
            InsertmanyvaluesSentinelOpts.AUTOINCREMENT
        ),
    }
)
5561
5562 def _get_sentinel_column_for_table(
5563 self, table: Table
5564 ) -> Optional[Sequence[Column[Any]]]:
5565 """given a :class:`.Table`, return a usable sentinel column or
5566 columns for this dialect if any.
5567
5568 Return None if no sentinel columns could be identified, or raise an
5569 error if a column was marked as a sentinel explicitly but isn't
5570 compatible with this dialect.
5571
5572 """
5573
5574 sentinel_opts = self.dialect.insertmanyvalues_implicit_sentinel
5575 sentinel_characteristics = table._sentinel_column_characteristics
5576
5577 sent_cols = sentinel_characteristics.columns
5578
5579 if sent_cols is None:
5580 return None
5581
5582 if sentinel_characteristics.is_autoinc:
5583 bitmask = self._sentinel_col_autoinc_lookup.get(
5584 sentinel_characteristics.default_characterization, 0
5585 )
5586 else:
5587 bitmask = self._sentinel_col_non_autoinc_lookup.get(
5588 sentinel_characteristics.default_characterization, 0
5589 )
5590
5591 if sentinel_opts & bitmask:
5592 return sent_cols
5593
5594 if sentinel_characteristics.is_explicit:
5595 # a column was explicitly marked as insert_sentinel=True,
5596 # however it is not compatible with this dialect. they should
5597 # not indicate this column as a sentinel if they need to include
5598 # this dialect.
5599
5600 # TODO: do we want non-primary key explicit sentinel cols
5601 # that can gracefully degrade for some backends?
5602 # insert_sentinel="degrade" perhaps. not for the initial release.
5603 # I am hoping people are generally not dealing with this sentinel
5604 # business at all.
5605
5606 # if is_explicit is True, there will be only one sentinel column.
5607
5608 raise exc.InvalidRequestError(
5609 f"Column {sent_cols[0]} can't be explicitly "
5610 "marked as a sentinel column when using the "
5611 f"{self.dialect.name} dialect, as the "
5612 "particular type of default generation on this column is "
5613 "not currently compatible with this dialect's specific "
5614 f"INSERT..RETURNING syntax which can receive the "
5615 "server-generated value in "
5616 "a deterministic way. To remove this error, remove "
5617 "insert_sentinel=True from primary key autoincrement "
5618 "columns; these columns are automatically used as "
5619 "sentinels for supported dialects in any case."
5620 )
5621
5622 return None
5623
def _deliver_insertmanyvalues_batches(
    self,
    statement: str,
    parameters: _DBAPIMultiExecuteParams,
    compiled_parameters: List[_MutableCoreSingleExecuteParams],
    generic_setinputsizes: Optional[_GenericSetInputSizesType],
    batch_size: int,
    sort_by_parameter_order: bool,
    schema_translate_map: Optional[SchemaTranslateMapType],
) -> Iterator[_InsertManyValuesBatch]:
    """Yield :class:`._InsertManyValuesBatch` objects for an
    "insertmanyvalues" INSERT.

    Two modes are used:

    * "row at a time": the statement is left unchanged and one batch
      of size 1 is yielded per parameter set.  This is chosen when the
      INSERT consists only of defaults and the dialect does not support
      the DEFAULT metavalue, when the dialect does not support
      multi-values INSERT, or when deterministic ordering was requested
      for a statement with result columns but there are no sentinel
      columns / the statement includes upsert behaviors.

    * batched: the single ``(...)`` VALUES expression in ``statement``
      is swapped for a token, and for each slice of up to
      ``batch_size`` parameter sets that token is replaced by an
      expanded list of VALUES tuples, with the parameters of the slice
      merged into one parameter collection (a tuple for positional
      styles, a dictionary with ``<key>__<index>`` names otherwise).

    :param statement: the compiled INSERT string.
    :param parameters: the DBAPI-level parameter sets, one per row.
    :param compiled_parameters: parameter sets parallel to
     ``parameters``, from which sentinel values are extracted.
    :param generic_setinputsizes: optional ``(key, len, type)`` entries
     which are repeated for each row of a batch using ``<key>_<index>``
     names.
    :param batch_size: maximum number of parameter sets per batch; may
     be reduced further by the dialect's
     ``insertmanyvalues_max_parameters``.
    :param sort_by_parameter_order: passed through to each batch; also
     takes part in the "row at a time" decision.
    :param schema_translate_map: when given, schema translation is
     applied to the VALUES expressions before they are searched for in
     ``statement``.

    """
    imv = self._insertmanyvalues
    assert imv is not None

    # callable that extracts the sentinel value(s) from one compiled
    # parameter set, or None when there are no sentinel parameter keys
    if not imv.sentinel_param_keys:
        _sentinel_from_params = None
    else:
        _sentinel_from_params = operator.itemgetter(
            *imv.sentinel_param_keys
        )

    lenparams = len(parameters)
    if imv.is_default_expr and not self.dialect.supports_default_metavalue:
        # backend doesn't support
        # INSERT INTO table (pk_col) VALUES (DEFAULT), (DEFAULT), ...
        # at the moment this is basically SQL Server due to
        # not being able to use DEFAULT for identity column
        # just yield out that many single statements!  still
        # faster than a whole connection.execute() call ;)
        #
        # note we still are taking advantage of the fact that we know
        # we are using RETURNING.   The generalized approach of fetching
        # cursor.lastrowid etc. still goes through the more heavyweight
        # "ExecutionContext per statement" system as it isn't usable
        # as a generic "RETURNING" approach
        use_row_at_a_time = True
        downgraded = False
    elif not self.dialect.supports_multivalues_insert or (
        sort_by_parameter_order
        and self._result_columns
        and (imv.sentinel_columns is None or imv.includes_upsert_behaviors)
    ):
        # deterministic order was requested and the compiler could
        # not organize sentinel columns for this dialect/statement.
        # use row at a time
        use_row_at_a_time = True
        downgraded = True
    else:
        use_row_at_a_time = False
        downgraded = False

    if use_row_at_a_time:
        # one batch per parameter set; batch numbers start at 1 and the
        # total number of batches equals the number of parameter sets
        for batchnum, (param, compiled_param) in enumerate(
            cast(
                "Sequence[Tuple[_DBAPISingleExecuteParams, _MutableCoreSingleExecuteParams]]",  # noqa: E501
                zip(parameters, compiled_parameters),
            ),
            1,
        ):
            yield _InsertManyValuesBatch(
                statement,
                param,
                generic_setinputsizes,
                [param],
                (
                    [_sentinel_from_params(compiled_param)]
                    if _sentinel_from_params
                    else []
                ),
                1,
                batchnum,
                lenparams,
                sort_by_parameter_order,
                downgraded,
            )
        return

    if schema_translate_map:
        rst = functools.partial(
            self.preparer._render_schema_translates,
            schema_translate_map=schema_translate_map,
        )
    else:
        rst = None

    imv_single_values_expr = imv.single_values_expr
    if rst:
        imv_single_values_expr = rst(imv_single_values_expr)

    # swap the single "(...)" VALUES expression for a token; the token
    # is replaced with the expanded VALUES list for each batch below
    executemany_values = f"({imv_single_values_expr})"
    statement = statement.replace(executemany_values, "__EXECMANY_TOKEN__")

    # Use optional insertmanyvalues_max_parameters
    # to further shrink the batch size so that there are no more than
    # insertmanyvalues_max_parameters params.
    # Currently used by SQL Server, which limits statements to 2100 bound
    # parameters (actually 2099).
    max_params = self.dialect.insertmanyvalues_max_parameters
    if max_params:
        total_num_of_params = len(self.bind_names)
        num_params_per_batch = len(imv.insert_crud_params)
        num_params_outside_of_batch = (
            total_num_of_params - num_params_per_batch
        )
        batch_size = min(
            batch_size,
            (
                (max_params - num_params_outside_of_batch)
                // num_params_per_batch
            ),
        )

    # mutable copies; slices are removed from the front as batches are
    # produced
    batches = cast("List[Sequence[Any]]", list(parameters))
    compiled_batches = cast(
        "List[Sequence[Any]]", list(compiled_parameters)
    )

    processed_setinputsizes: Optional[_GenericSetInputSizesType] = None
    batchnum = 1
    # ceiling division of lenparams by batch_size
    total_batches = lenparams // batch_size + (
        1 if lenparams % batch_size else 0
    )

    insert_crud_params = imv.insert_crud_params
    assert insert_crud_params is not None

    if rst:
        insert_crud_params = [
            (col, key, rst(expr), st)
            for col, key, expr, st in insert_crud_params
        ]

    escaped_bind_names: Mapping[str, str]
    expand_pos_lower_index = expand_pos_upper_index = 0

    if not self.positional:
        # named parameter style: build a VALUES template in which each
        # bound parameter name carries an "__EXECMANY_INDEX__" suffix
        # that is replaced with the row index per batch
        if self.escaped_bind_names:
            escaped_bind_names = self.escaped_bind_names
        else:
            escaped_bind_names = {}

        all_keys = set(parameters[0])

        def apply_placeholders(keys, formatted):
            for key in keys:
                key = escaped_bind_names.get(key, key)
                formatted = formatted.replace(
                    self.bindtemplate % {"name": key},
                    self.bindtemplate
                    % {"name": f"{key}__EXECMANY_INDEX__"},
                )
            return formatted

        if imv.embed_values_counter:
            imv_values_counter = ", _IMV_VALUES_COUNTER"
        else:
            imv_values_counter = ""
        formatted_values_clause = f"""({', '.join(
            apply_placeholders(bind_keys, formatted)
            for _, _, formatted, bind_keys in insert_crud_params
        )}{imv_values_counter})"""

        # keys that belong to the VALUES clause are re-keyed per row;
        # all remaining keys are taken once, from the first parameter
        # set
        keys_to_replace = all_keys.intersection(
            escaped_bind_names.get(key, key)
            for _, _, _, bind_keys in insert_crud_params
            for key in bind_keys
        )
        base_parameters = {
            key: parameters[0][key]
            for key in all_keys.difference(keys_to_replace)
        }
        executemany_values_w_comma = ""
    else:
        formatted_values_clause = ""
        keys_to_replace = set()
        base_parameters = {}

        # positional style: the VALUES tuple followed by ", " is
        # repeated per row; the trailing ", " is sliced off later
        if imv.embed_values_counter:
            executemany_values_w_comma = (
                f"({imv_single_values_expr}, _IMV_VALUES_COUNTER), "
            )
        else:
            executemany_values_w_comma = f"({imv_single_values_expr}), "

        all_names_we_will_expand: Set[str] = set()
        for elem in imv.insert_crud_params:
            all_names_we_will_expand.update(elem[3])

        # get the start and end position in a particular list
        # of parameters where we will be doing the "expanding".
        # statements can have params on either side or both sides,
        # given RETURNING and CTEs
        if all_names_we_will_expand:
            positiontup = self.positiontup
            assert positiontup is not None

            all_expand_positions = {
                idx
                for idx, name in enumerate(positiontup)
                if name in all_names_we_will_expand
            }
            expand_pos_lower_index = min(all_expand_positions)
            expand_pos_upper_index = max(all_expand_positions) + 1
            # the expanded parameters must form one contiguous run
            assert (
                len(all_expand_positions)
                == expand_pos_upper_index - expand_pos_lower_index
            )

        if self._numeric_binds:
            # replace the numbered placeholders with "%s" so that the
            # expanded string can be %-formatted with per-batch
            # position numbers further below
            escaped = re.escape(self._numeric_binds_identifier_char)
            executemany_values_w_comma = re.sub(
                rf"{escaped}\d+", "%s", executemany_values_w_comma
            )

    while batches:
        # take the next slice of parameter sets off the front
        batch = batches[0:batch_size]
        compiled_batch = compiled_batches[0:batch_size]

        batches[0:batch_size] = []
        compiled_batches[0:batch_size] = []

        # only the final batch can be shorter than batch_size
        if batches:
            current_batch_size = batch_size
        else:
            current_batch_size = len(batch)

        if generic_setinputsizes:
            # if setinputsizes is present, expand this collection to
            # suit the batch length as well
            # currently this will be mssql+pyodbc for internal dialects
            processed_setinputsizes = [
                (new_key, len_, typ)
                for new_key, len_, typ in (
                    (f"{key}_{index}", len_, typ)
                    for index in range(current_batch_size)
                    for key, len_, typ in generic_setinputsizes
                )
            ]

        replaced_parameters: Any
        if self.positional:
            num_ins_params = imv.num_positional_params_counted

            batch_iterator: Iterable[Sequence[Any]]
            extra_params_left: Sequence[Any]
            extra_params_right: Sequence[Any]

            if num_ins_params == len(batch[0]):
                # every positional parameter belongs to VALUES
                extra_params_left = extra_params_right = ()
                batch_iterator = batch
            else:
                # parameters outside of VALUES are taken from the first
                # parameter set of the batch only
                extra_params_left = batch[0][:expand_pos_lower_index]
                extra_params_right = batch[0][expand_pos_upper_index:]
                batch_iterator = (
                    b[expand_pos_lower_index:expand_pos_upper_index]
                    for b in batch
                )

            # the [:-2] slices remove the trailing ", "
            if imv.embed_values_counter:
                expanded_values_string = (
                    "".join(
                        executemany_values_w_comma.replace(
                            "_IMV_VALUES_COUNTER", str(i)
                        )
                        for i, _ in enumerate(batch)
                    )
                )[:-2]
            else:
                expanded_values_string = (
                    (executemany_values_w_comma * current_batch_size)
                )[:-2]

            if self._numeric_binds and num_ins_params > 0:
                # numeric will always number the parameters inside of
                # VALUES (and thus order self.positiontup) to be higher
                # than non-VALUES parameters, no matter where in the
                # statement those non-VALUES parameters appear (this is
                # ensured in _process_numeric by numbering first all
                # params that are not in _values_bindparam)
                # therefore all extra params are always
                # on the left side and numbered lower than the VALUES
                # parameters
                assert not extra_params_right

                start = expand_pos_lower_index + 1
                end = num_ins_params * (current_batch_size) + start

                # need to format here, since statement may contain
                # unescaped %, while values_string contains just (%s, %s)
                positions = tuple(
                    f"{self._numeric_binds_identifier_char}{i}"
                    for i in range(start, end)
                )
                expanded_values_string = expanded_values_string % positions

            replaced_statement = statement.replace(
                "__EXECMANY_TOKEN__", expanded_values_string
            )

            # flatten the per-row VALUES parameters into one tuple
            replaced_parameters = tuple(
                itertools.chain.from_iterable(batch_iterator)
            )

            replaced_parameters = (
                extra_params_left
                + replaced_parameters
                + extra_params_right
            )

        else:
            replaced_values_clauses = []
            replaced_parameters = base_parameters.copy()

            for i, param in enumerate(batch):
                # "<key>__EXECMANY_INDEX__" becomes "<key>__<i>", which
                # matches the parameter keys generated just below
                fmv = formatted_values_clause.replace(
                    "EXECMANY_INDEX__", str(i)
                )
                if imv.embed_values_counter:
                    fmv = fmv.replace("_IMV_VALUES_COUNTER", str(i))

                replaced_values_clauses.append(fmv)
                replaced_parameters.update(
                    {f"{key}__{i}": param[key] for key in keys_to_replace}
                )

            replaced_statement = statement.replace(
                "__EXECMANY_TOKEN__",
                ", ".join(replaced_values_clauses),
            )

        yield _InsertManyValuesBatch(
            replaced_statement,
            replaced_parameters,
            processed_setinputsizes,
            batch,
            (
                [_sentinel_from_params(cb) for cb in compiled_batch]
                if _sentinel_from_params
                else []
            ),
            current_batch_size,
            batchnum,
            total_batches,
            sort_by_parameter_order,
            False,
        )
        batchnum += 1
5970
5971 def visit_insert(
5972 self, insert_stmt, visited_bindparam=None, visiting_cte=None, **kw
5973 ):
5974 compile_state = insert_stmt._compile_state_factory(
5975 insert_stmt, self, **kw
5976 )
5977 insert_stmt = compile_state.statement
5978
5979 if visiting_cte is not None:
5980 kw["visiting_cte"] = visiting_cte
5981 toplevel = False
5982 else:
5983 toplevel = not self.stack
5984
5985 if toplevel:
5986 self.isinsert = True
5987 if not self.dml_compile_state:
5988 self.dml_compile_state = compile_state
5989 if not self.compile_state:
5990 self.compile_state = compile_state
5991
5992 self.stack.append(
5993 {
5994 "correlate_froms": set(),
5995 "asfrom_froms": set(),
5996 "selectable": insert_stmt,
5997 }
5998 )
5999
6000 counted_bindparam = 0
6001
6002 # reset any incoming "visited_bindparam" collection
6003 visited_bindparam = None
6004
6005 # for positional, insertmanyvalues needs to know how many
6006 # bound parameters are in the VALUES sequence; there's no simple
6007 # rule because default expressions etc. can have zero or more
6008 # params inside them. After multiple attempts to figure this out,
6009 # this very simplistic "count after" works and is
6010 # likely the least amount of callcounts, though looks clumsy
6011 if self.positional and visiting_cte is None:
6012 # if we are inside a CTE, don't count parameters
6013 # here since they wont be for insertmanyvalues. keep
6014 # visited_bindparam at None so no counting happens.
6015 # see #9173
6016 visited_bindparam = []
6017
6018 crud_params_struct = crud._get_crud_params(
6019 self,
6020 insert_stmt,
6021 compile_state,
6022 toplevel,
6023 visited_bindparam=visited_bindparam,
6024 **kw,
6025 )
6026
6027 if self.positional and visited_bindparam is not None:
6028 counted_bindparam = len(visited_bindparam)
6029 if self._numeric_binds:
6030 if self._values_bindparam is not None:
6031 self._values_bindparam += visited_bindparam
6032 else:
6033 self._values_bindparam = visited_bindparam
6034
6035 crud_params_single = crud_params_struct.single_params
6036
6037 if (
6038 not crud_params_single
6039 and not self.dialect.supports_default_values
6040 and not self.dialect.supports_default_metavalue
6041 and not self.dialect.supports_empty_insert
6042 ):
6043 raise exc.CompileError(
6044 "The '%s' dialect with current database "
6045 "version settings does not support empty "
6046 "inserts." % self.dialect.name
6047 )
6048
6049 if compile_state._has_multi_parameters:
6050 if not self.dialect.supports_multivalues_insert:
6051 raise exc.CompileError(
6052 "The '%s' dialect with current database "
6053 "version settings does not support "
6054 "in-place multirow inserts." % self.dialect.name
6055 )
6056 elif (
6057 self.implicit_returning or insert_stmt._returning
6058 ) and insert_stmt._sort_by_parameter_order:
6059 raise exc.CompileError(
6060 "RETURNING cannot be determinstically sorted when "
6061 "using an INSERT which includes multi-row values()."
6062 )
6063 crud_params_single = crud_params_struct.single_params
6064 else:
6065 crud_params_single = crud_params_struct.single_params
6066
6067 preparer = self.preparer
6068 supports_default_values = self.dialect.supports_default_values
6069
6070 text = "INSERT "
6071
6072 if insert_stmt._prefixes:
6073 text += self._generate_prefixes(
6074 insert_stmt, insert_stmt._prefixes, **kw
6075 )
6076
6077 text += "INTO "
6078 table_text = preparer.format_table(insert_stmt.table)
6079
6080 if insert_stmt._hints:
6081 _, table_text = self._setup_crud_hints(insert_stmt, table_text)
6082
6083 if insert_stmt._independent_ctes:
6084 self._dispatch_independent_ctes(insert_stmt, kw)
6085
6086 text += table_text
6087
6088 if crud_params_single or not supports_default_values:
6089 text += " (%s)" % ", ".join(
6090 [expr for _, expr, _, _ in crud_params_single]
6091 )
6092
6093 # look for insertmanyvalues attributes that would have been configured
6094 # by crud.py as it scanned through the columns to be part of the
6095 # INSERT
6096 use_insertmanyvalues = crud_params_struct.use_insertmanyvalues
6097 named_sentinel_params: Optional[Sequence[str]] = None
6098 add_sentinel_cols = None
6099 implicit_sentinel = False
6100
6101 returning_cols = self.implicit_returning or insert_stmt._returning
6102 if returning_cols:
6103 add_sentinel_cols = crud_params_struct.use_sentinel_columns
6104 if add_sentinel_cols is not None:
6105 assert use_insertmanyvalues
6106
6107 # search for the sentinel column explicitly present
6108 # in the INSERT columns list, and additionally check that
6109 # this column has a bound parameter name set up that's in the
6110 # parameter list. If both of these cases are present, it means
6111 # we will have a client side value for the sentinel in each
6112 # parameter set.
6113
6114 _params_by_col = {
6115 col: param_names
6116 for col, _, _, param_names in crud_params_single
6117 }
6118 named_sentinel_params = []
6119 for _add_sentinel_col in add_sentinel_cols:
6120 if _add_sentinel_col not in _params_by_col:
6121 named_sentinel_params = None
6122 break
6123 param_name = self._within_exec_param_key_getter(
6124 _add_sentinel_col
6125 )
6126 if param_name not in _params_by_col[_add_sentinel_col]:
6127 named_sentinel_params = None
6128 break
6129 named_sentinel_params.append(param_name)
6130
6131 if named_sentinel_params is None:
6132 # if we are not going to have a client side value for
6133 # the sentinel in the parameter set, that means it's
6134 # an autoincrement, an IDENTITY, or a server-side SQL
6135 # expression like nextval('seqname'). So this is
6136 # an "implicit" sentinel; we will look for it in
6137 # RETURNING
6138 # only, and then sort on it. For this case on PG,
6139 # SQL Server we have to use a special INSERT form
6140 # that guarantees the server side function lines up with
6141 # the entries in the VALUES.
6142 if (
6143 self.dialect.insertmanyvalues_implicit_sentinel
6144 & InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT
6145 ):
6146 implicit_sentinel = True
6147 else:
6148 # here, we are not using a sentinel at all
6149 # and we are likely the SQLite dialect.
6150 # The first add_sentinel_col that we have should not
6151 # be marked as "insert_sentinel=True". if it was,
6152 # an error should have been raised in
6153 # _get_sentinel_column_for_table.
6154 assert not add_sentinel_cols[0]._insert_sentinel, (
6155 "sentinel selection rules should have prevented "
6156 "us from getting here for this dialect"
6157 )
6158
6159 # always put the sentinel columns last. even if they are
6160 # in the returning list already, they will be there twice
6161 # then.
6162 returning_cols = list(returning_cols) + list(add_sentinel_cols)
6163
6164 returning_clause = self.returning_clause(
6165 insert_stmt,
6166 returning_cols,
6167 populate_result_map=toplevel,
6168 )
6169
6170 if self.returning_precedes_values:
6171 text += " " + returning_clause
6172
6173 else:
6174 returning_clause = None
6175
6176 if insert_stmt.select is not None:
6177 # placed here by crud.py
6178 select_text = self.process(
6179 self.stack[-1]["insert_from_select"], insert_into=True, **kw
6180 )
6181
6182 if self.ctes and self.dialect.cte_follows_insert:
6183 nesting_level = len(self.stack) if not toplevel else None
6184 text += " %s%s" % (
6185 self._render_cte_clause(
6186 nesting_level=nesting_level,
6187 include_following_stack=True,
6188 ),
6189 select_text,
6190 )
6191 else:
6192 text += " %s" % select_text
6193 elif not crud_params_single and supports_default_values:
6194 text += " DEFAULT VALUES"
6195 if use_insertmanyvalues:
6196 self._insertmanyvalues = _InsertManyValues(
6197 True,
6198 self.dialect.default_metavalue_token,
6199 cast(
6200 "List[crud._CrudParamElementStr]", crud_params_single
6201 ),
6202 counted_bindparam,
6203 sort_by_parameter_order=(
6204 insert_stmt._sort_by_parameter_order
6205 ),
6206 includes_upsert_behaviors=(
6207 insert_stmt._post_values_clause is not None
6208 ),
6209 sentinel_columns=add_sentinel_cols,
6210 num_sentinel_columns=(
6211 len(add_sentinel_cols) if add_sentinel_cols else 0
6212 ),
6213 implicit_sentinel=implicit_sentinel,
6214 )
6215 elif compile_state._has_multi_parameters:
6216 text += " VALUES %s" % (
6217 ", ".join(
6218 "(%s)"
6219 % (", ".join(value for _, _, value, _ in crud_param_set))
6220 for crud_param_set in crud_params_struct.all_multi_params
6221 ),
6222 )
6223 else:
6224 insert_single_values_expr = ", ".join(
6225 [
6226 value
6227 for _, _, value, _ in cast(
6228 "List[crud._CrudParamElementStr]",
6229 crud_params_single,
6230 )
6231 ]
6232 )
6233
6234 if use_insertmanyvalues:
6235 if (
6236 implicit_sentinel
6237 and (
6238 self.dialect.insertmanyvalues_implicit_sentinel
6239 & InsertmanyvaluesSentinelOpts.USE_INSERT_FROM_SELECT
6240 )
6241 # this is checking if we have
6242 # INSERT INTO table (id) VALUES (DEFAULT).
6243 and not (crud_params_struct.is_default_metavalue_only)
6244 ):
6245 # if we have a sentinel column that is server generated,
6246 # then for selected backends render the VALUES list as a
6247 # subquery. This is the orderable form supported by
6248 # PostgreSQL and SQL Server.
6249 embed_sentinel_value = True
6250
6251 render_bind_casts = (
6252 self.dialect.insertmanyvalues_implicit_sentinel
6253 & InsertmanyvaluesSentinelOpts.RENDER_SELECT_COL_CASTS
6254 )
6255
6256 colnames = ", ".join(
6257 f"p{i}" for i, _ in enumerate(crud_params_single)
6258 )
6259
6260 if render_bind_casts:
6261 # render casts for the SELECT list. For PG, we are
6262 # already rendering bind casts in the parameter list,
6263 # selectively for the more "tricky" types like ARRAY.
6264 # however, even for the "easy" types, if the parameter
6265 # is NULL for every entry, PG gives up and says
6266 # "it must be TEXT", which fails for other easy types
6267 # like ints. So we cast on this side too.
6268 colnames_w_cast = ", ".join(
6269 self.render_bind_cast(
6270 col.type,
6271 col.type._unwrapped_dialect_impl(self.dialect),
6272 f"p{i}",
6273 )
6274 for i, (col, *_) in enumerate(crud_params_single)
6275 )
6276 else:
6277 colnames_w_cast = colnames
6278
6279 text += (
6280 f" SELECT {colnames_w_cast} FROM "
6281 f"(VALUES ({insert_single_values_expr})) "
6282 f"AS imp_sen({colnames}, sen_counter) "
6283 "ORDER BY sen_counter"
6284 )
6285 else:
6286 # otherwise, if no sentinel or backend doesn't support
6287 # orderable subquery form, use a plain VALUES list
6288 embed_sentinel_value = False
6289 text += f" VALUES ({insert_single_values_expr})"
6290
6291 self._insertmanyvalues = _InsertManyValues(
6292 is_default_expr=False,
6293 single_values_expr=insert_single_values_expr,
6294 insert_crud_params=cast(
6295 "List[crud._CrudParamElementStr]",
6296 crud_params_single,
6297 ),
6298 num_positional_params_counted=counted_bindparam,
6299 sort_by_parameter_order=(
6300 insert_stmt._sort_by_parameter_order
6301 ),
6302 includes_upsert_behaviors=(
6303 insert_stmt._post_values_clause is not None
6304 ),
6305 sentinel_columns=add_sentinel_cols,
6306 num_sentinel_columns=(
6307 len(add_sentinel_cols) if add_sentinel_cols else 0
6308 ),
6309 sentinel_param_keys=named_sentinel_params,
6310 implicit_sentinel=implicit_sentinel,
6311 embed_values_counter=embed_sentinel_value,
6312 )
6313
6314 else:
6315 text += f" VALUES ({insert_single_values_expr})"
6316
6317 if insert_stmt._post_values_clause is not None:
6318 post_values_clause = self.process(
6319 insert_stmt._post_values_clause, **kw
6320 )
6321 if post_values_clause:
6322 text += " " + post_values_clause
6323
6324 if returning_clause and not self.returning_precedes_values:
6325 text += " " + returning_clause
6326
6327 if self.ctes and not self.dialect.cte_follows_insert:
6328 nesting_level = len(self.stack) if not toplevel else None
6329 text = (
6330 self._render_cte_clause(
6331 nesting_level=nesting_level,
6332 include_following_stack=True,
6333 )
6334 + text
6335 )
6336
6337 self.stack.pop(-1)
6338
6339 return text
6340
6341 def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
6342 """Provide a hook to override the initial table clause
6343 in an UPDATE statement.
6344
6345 MySQL overrides this.
6346
6347 """
6348 kw["asfrom"] = True
6349 return from_table._compiler_dispatch(self, iscrud=True, **kw)
6350
6351 def update_from_clause(
6352 self, update_stmt, from_table, extra_froms, from_hints, **kw
6353 ):
6354 """Provide a hook to override the generation of an
6355 UPDATE..FROM clause.
6356 MySQL and MSSQL override this.
6357 """
6358 raise NotImplementedError(
6359 "This backend does not support multiple-table "
6360 "criteria within UPDATE"
6361 )
6362
6363 def update_post_criteria_clause(
6364 self, update_stmt: Update, **kw: Any
6365 ) -> Optional[str]:
6366 """provide a hook to override generation after the WHERE criteria
6367 in an UPDATE statement
6368
6369 .. versionadded:: 2.1
6370
6371 """
6372 if update_stmt._post_criteria_clause is not None:
6373 return self.process(
6374 update_stmt._post_criteria_clause,
6375 **kw,
6376 )
6377 else:
6378 return None
6379
6380 def delete_post_criteria_clause(
6381 self, delete_stmt: Delete, **kw: Any
6382 ) -> Optional[str]:
6383 """provide a hook to override generation after the WHERE criteria
6384 in a DELETE statement
6385
6386 .. versionadded:: 2.1
6387
6388 """
6389 if delete_stmt._post_criteria_clause is not None:
6390 return self.process(
6391 delete_stmt._post_criteria_clause,
6392 **kw,
6393 )
6394 else:
6395 return None
6396
    def visit_update(
        self,
        update_stmt: Update,
        visiting_cte: Optional[CTE] = None,
        **kw: Any,
    ) -> str:
        """Compile an UPDATE construct into its SQL string.

        The text is assembled in this order: ``UPDATE``, statement
        prefixes, the table clause, ``SET`` with the column assignments,
        RETURNING (at this point only when ``returning_precedes_values``
        is set), the extra-FROM clause, ``WHERE``, the post-criteria
        clause, and RETURNING (when it does not precede values).  Any
        CTEs collected on the compiler are prepended at the end.

        :param update_stmt: the UPDATE construct being compiled.
        :param visiting_cte: when not ``None``, the statement is treated
         as non-toplevel and the value is forwarded to nested compilation
         through ``kw``.
        :param kw: additional keyword arguments passed along to nested
         compilation calls.
        :return: the rendered UPDATE statement.

        """
        compile_state = update_stmt._compile_state_factory(
            update_stmt, self, **kw
        )
        if TYPE_CHECKING:
            assert isinstance(compile_state, UpdateDMLState)
        # from here on, work with the statement held by the compile state
        update_stmt = compile_state.statement  # type: ignore[assignment]

        if visiting_cte is not None:
            # an UPDATE rendered on behalf of a CTE is never toplevel
            kw["visiting_cte"] = visiting_cte
            toplevel = False
        else:
            # toplevel means no enclosing statement is on the stack
            toplevel = not self.stack

        if toplevel:
            self.isupdate = True
            # only fill in compile state that has not already been set
            if not self.dml_compile_state:
                self.dml_compile_state = compile_state
            if not self.compile_state:
                self.compile_state = compile_state

        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
            from_linter = FromLinter({}, set())
            warn_linting = self.linting & WARN_LINTING
            if toplevel:
                self.from_linter = from_linter
        else:
            from_linter = None
            warn_linting = False

        extra_froms = compile_state._extra_froms
        is_multitable = bool(extra_froms)

        if is_multitable:
            # main table might be a JOIN
            main_froms = set(_from_objects(update_stmt.table))
            # FROM objects already covered by the main table are not
            # rendered a second time as extra froms
            render_extra_froms = [
                f for f in extra_froms if f not in main_froms
            ]
            correlate_froms = main_froms.union(extra_froms)
        else:
            render_extra_froms = []
            correlate_froms = {update_stmt.table}

        # push this statement's entry; it is popped again just before
        # returning
        self.stack.append(
            {
                "correlate_froms": correlate_froms,
                "asfrom_froms": correlate_froms,
                "selectable": update_stmt,
            }
        )

        text = "UPDATE "

        if update_stmt._prefixes:
            text += self._generate_prefixes(
                update_stmt, update_stmt._prefixes, **kw
            )

        table_text = self.update_tables_clause(
            update_stmt,
            update_stmt.table,
            render_extra_froms,
            from_linter=from_linter,
            **kw,
        )
        crud_params_struct = crud._get_crud_params(
            self, update_stmt, compile_state, toplevel, **kw
        )
        crud_params = crud_params_struct.single_params

        if update_stmt._hints:
            # hints may rewrite the already-rendered table text
            dialect_hints, table_text = self._setup_crud_hints(
                update_stmt, table_text
            )
        else:
            dialect_hints = None

        if update_stmt._independent_ctes:
            self._dispatch_independent_ctes(update_stmt, kw)

        text += table_text

        text += " SET "
        # each crud parameter contributes one "<expr>=<value>" assignment
        text += ", ".join(
            expr + "=" + value
            for _, expr, value, _ in cast(
                "List[Tuple[Any, str, str, Any]]", crud_params
            )
        )

        # RETURNING rendered here only when it precedes the values
        if self.implicit_returning or update_stmt._returning:
            if self.returning_precedes_values:
                text += " " + self.returning_clause(
                    update_stmt,
                    self.implicit_returning or update_stmt._returning,
                    populate_result_map=toplevel,
                )

        if extra_froms:
            extra_from_text = self.update_from_clause(
                update_stmt,
                update_stmt.table,
                render_extra_froms,
                dialect_hints,
                from_linter=from_linter,
                **kw,
            )
            if extra_from_text:
                text += " " + extra_from_text

        if update_stmt._where_criteria:
            t = self._generate_delimited_and_list(
                update_stmt._where_criteria, from_linter=from_linter, **kw
            )
            if t:
                text += " WHERE " + t

        ulc = self.update_post_criteria_clause(
            update_stmt, from_linter=from_linter, **kw
        )
        if ulc:
            text += " " + ulc

        # RETURNING rendered at the end when it does not precede values
        if (
            self.implicit_returning or update_stmt._returning
        ) and not self.returning_precedes_values:
            text += " " + self.returning_clause(
                update_stmt,
                self.implicit_returning or update_stmt._returning,
                populate_result_map=toplevel,
            )

        if self.ctes:
            # the CTE clause is prepended to the finished statement text;
            # a nesting level is passed only for non-toplevel statements
            nesting_level = len(self.stack) if not toplevel else None
            text = self._render_cte_clause(nesting_level=nesting_level) + text

        if warn_linting:
            assert from_linter is not None
            from_linter.warn(stmt_type="UPDATE")

        self.stack.pop(-1)

        return text  # type: ignore[no-any-return]
6546
6547 def delete_extra_from_clause(
6548 self, delete_stmt, from_table, extra_froms, from_hints, **kw
6549 ):
6550 """Provide a hook to override the generation of an
6551 DELETE..FROM clause.
6552
6553 This can be used to implement DELETE..USING for example.
6554
6555 MySQL and MSSQL override this.
6556
6557 """
6558 raise NotImplementedError(
6559 "This backend does not support multiple-table "
6560 "criteria within DELETE"
6561 )
6562
6563 def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
6564 return from_table._compiler_dispatch(
6565 self, asfrom=True, iscrud=True, **kw
6566 )
6567
    def visit_delete(self, delete_stmt, visiting_cte=None, **kw):
        """Compile a DELETE construct into its SQL string.

        The text is assembled in this order: ``DELETE``, statement
        prefixes, ``FROM`` and the table clause, RETURNING (at this point
        only when ``returning_precedes_values`` is set), the extra-FROM
        clause, ``WHERE``, the post-criteria clause, and RETURNING (when
        it does not precede values).  Any CTEs collected on the compiler
        are prepended at the end.

        :param delete_stmt: the DELETE construct being compiled.
        :param visiting_cte: when not ``None``, the statement is treated
         as non-toplevel and the value is forwarded to nested compilation
         through ``kw``.
        :param kw: additional keyword arguments passed along to nested
         compilation calls.
        :return: the rendered DELETE statement.

        """
        compile_state = delete_stmt._compile_state_factory(
            delete_stmt, self, **kw
        )
        # from here on, work with the statement held by the compile state
        delete_stmt = compile_state.statement

        if visiting_cte is not None:
            # a DELETE rendered on behalf of a CTE is never toplevel
            kw["visiting_cte"] = visiting_cte
            toplevel = False
        else:
            # toplevel means no enclosing statement is on the stack
            toplevel = not self.stack

        if toplevel:
            self.isdelete = True
            # only fill in compile state that has not already been set
            if not self.dml_compile_state:
                self.dml_compile_state = compile_state
            if not self.compile_state:
                self.compile_state = compile_state

        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
            from_linter = FromLinter({}, set())
            warn_linting = self.linting & WARN_LINTING
            if toplevel:
                self.from_linter = from_linter
        else:
            from_linter = None
            warn_linting = False

        extra_froms = compile_state._extra_froms

        correlate_froms = {delete_stmt.table}.union(extra_froms)
        # push this statement's entry; it is popped again just before
        # returning
        self.stack.append(
            {
                "correlate_froms": correlate_froms,
                "asfrom_froms": correlate_froms,
                "selectable": delete_stmt,
            }
        )

        text = "DELETE "

        if delete_stmt._prefixes:
            text += self._generate_prefixes(
                delete_stmt, delete_stmt._prefixes, **kw
            )

        text += "FROM "

        # NOTE(review): unlike the table clause call in visit_update, **kw
        # is not forwarded to delete_table_clause() here; only from_linter
        # is passed.
        try:
            table_text = self.delete_table_clause(
                delete_stmt,
                delete_stmt.table,
                extra_froms,
                from_linter=from_linter,
            )
        except TypeError:
            # anticipate 3rd party dialects that don't include **kw
            # TODO: remove in 2.1
            table_text = self.delete_table_clause(
                delete_stmt, delete_stmt.table, extra_froms
            )
            if from_linter:
                # the fallback call could not receive the linter, so the
                # table is processed once more with from_linter and the
                # rendered text is discarded
                _ = self.process(delete_stmt.table, from_linter=from_linter)

        # return value is not used here
        crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw)

        if delete_stmt._hints:
            # hints may rewrite the already-rendered table text
            dialect_hints, table_text = self._setup_crud_hints(
                delete_stmt, table_text
            )
        else:
            dialect_hints = None

        if delete_stmt._independent_ctes:
            self._dispatch_independent_ctes(delete_stmt, kw)

        text += table_text

        # RETURNING rendered here only when it precedes the values
        if (
            self.implicit_returning or delete_stmt._returning
        ) and self.returning_precedes_values:
            text += " " + self.returning_clause(
                delete_stmt,
                self.implicit_returning or delete_stmt._returning,
                populate_result_map=toplevel,
            )

        if extra_froms:
            extra_from_text = self.delete_extra_from_clause(
                delete_stmt,
                delete_stmt.table,
                extra_froms,
                dialect_hints,
                from_linter=from_linter,
                **kw,
            )
            if extra_from_text:
                text += " " + extra_from_text

        if delete_stmt._where_criteria:
            t = self._generate_delimited_and_list(
                delete_stmt._where_criteria, from_linter=from_linter, **kw
            )
            if t:
                text += " WHERE " + t

        dlc = self.delete_post_criteria_clause(
            delete_stmt, from_linter=from_linter, **kw
        )
        if dlc:
            text += " " + dlc

        # RETURNING rendered at the end when it does not precede values
        if (
            self.implicit_returning or delete_stmt._returning
        ) and not self.returning_precedes_values:
            text += " " + self.returning_clause(
                delete_stmt,
                self.implicit_returning or delete_stmt._returning,
                populate_result_map=toplevel,
            )

        if self.ctes:
            # the CTE clause is prepended to the finished statement text;
            # a nesting level is passed only for non-toplevel statements
            nesting_level = len(self.stack) if not toplevel else None
            text = self._render_cte_clause(nesting_level=nesting_level) + text

        if warn_linting:
            assert from_linter is not None
            from_linter.warn(stmt_type="DELETE")

        self.stack.pop(-1)

        return text
6700
6701 def visit_savepoint(self, savepoint_stmt, **kw):
6702 return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
6703
6704 def visit_rollback_to_savepoint(self, savepoint_stmt, **kw):
6705 return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(
6706 savepoint_stmt
6707 )
6708
6709 def visit_release_savepoint(self, savepoint_stmt, **kw):
6710 return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(
6711 savepoint_stmt
6712 )
6713
6714
class StrSQLCompiler(SQLCompiler):
    """A :class:`.SQLCompiler` subclass that renders a small selection of
    non-standard SQL features as plain strings.

    The :class:`.StrSQLCompiler` is invoked whenever a Core expression
    element is directly stringified without calling upon the
    :meth:`_expression.ClauseElement.compile` method.  It renders a limited
    set of non-standard SQL constructs to assist in basic stringification;
    more substantial custom or dialect-specific SQL constructs require
    calling :meth:`_expression.ClauseElement.compile` directly.

    .. seealso::

        :ref:`faq_sql_expression_string`

    """

    def _fallback_column_name(self, column):
        """Return the placeholder rendered for a column without a name."""
        return "<name unknown>"

    @util.preload_module("sqlalchemy.engine.url")
    def visit_unsupported_compilation(self, element, err, **kw):
        """Try the element's own ``stringify_dialect`` before giving up.

        When the element names a non-default stringify dialect and that
        dialect's statement compiler is not itself a
        :class:`.StrSQLCompiler`, that compiler renders the element.
        Otherwise the superclass handling is used.

        """
        dialect_name = element.stringify_dialect
        if dialect_name == "default":
            return super().visit_unsupported_compilation(element, err)

        engine_url = util.preloaded.engine_url
        dialect_cls = engine_url.URL.create(dialect_name).get_dialect()
        dialect = dialect_cls()

        delegate = dialect.statement_compiler(
            dialect, None, _supporting_against=self
        )
        if isinstance(delegate, StrSQLCompiler):
            return super().visit_unsupported_compilation(element, err)

        return delegate.process(element, **kw)

    def visit_getitem_binary(self, binary, operator, **kw):
        """Render an index operation as ``left[right]``."""
        target = self.process(binary.left, **kw)
        index = self.process(binary.right, **kw)
        return "%s[%s]" % (target, index)

    def visit_json_getitem_op_binary(self, binary, operator, **kw):
        """Render a JSON index operation the same way as a plain getitem."""
        return self.visit_getitem_binary(binary, operator, **kw)

    def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
        """Render a JSON path index operation the same way as a plain
        getitem."""
        return self.visit_getitem_binary(binary, operator, **kw)

    def visit_sequence(self, sequence, **kw):
        """Render a descriptive placeholder for a sequence's next value."""
        sequence_name = self.preparer.format_sequence(sequence)
        return f"<next sequence value: {sequence_name}>"

    def returning_clause(
        self,
        stmt: UpdateBase,
        returning_cols: Sequence[_ColumnsClauseElement],
        *,
        populate_result_map: bool,
        **kw: Any,
    ) -> str:
        """Render a generic ``RETURNING`` clause listing the given columns."""
        labeled = []
        for col in base._select_iterables(returning_cols):
            labeled.append(
                self._label_select_column(None, col, True, False, {})
            )
        return "RETURNING " + ", ".join(labeled)

    def update_from_clause(
        self, update_stmt, from_table, extra_froms, from_hints, **kw
    ):
        """Render the extra tables of an UPDATE as a ``FROM`` list."""
        dispatch_kw = {**kw, "asfrom": True}
        rendered = [
            t._compiler_dispatch(self, fromhints=from_hints, **dispatch_kw)
            for t in extra_froms
        ]
        return "FROM " + ", ".join(rendered)

    def delete_extra_from_clause(
        self, delete_stmt, from_table, extra_froms, from_hints, **kw
    ):
        """Render the extra tables of a DELETE as a comma-led list."""
        dispatch_kw = {**kw, "asfrom": True}
        rendered = [
            t._compiler_dispatch(self, fromhints=from_hints, **dispatch_kw)
            for t in extra_froms
        ]
        return ", " + ", ".join(rendered)

    def visit_empty_set_expr(self, element_types, **kw):
        """Return a SELECT whose WHERE criteria is ``1!=1``."""
        return "SELECT 1 WHERE 1!=1"

    def get_from_hint_text(self, table, text):
        """Render a FROM hint wrapped in square brackets."""
        return "[%s]" % text

    def visit_regexp_match_op_binary(self, binary, operator, **kw):
        """Render a regexp match using a ``<regexp>`` placeholder operator."""
        return self._generate_generic_binary(binary, " <regexp> ", **kw)

    def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
        """Render a negated regexp match using a ``<not regexp>`` placeholder
        operator."""
        return self._generate_generic_binary(binary, " <not regexp> ", **kw)

    def visit_regexp_replace_op_binary(self, binary, operator, **kw):
        """Render a regexp replace as a ``<regexp replace>(...)``
        placeholder call."""
        left_text = binary.left._compiler_dispatch(self, **kw)
        right_text = binary.right._compiler_dispatch(self, **kw)
        return "<regexp replace>(%s, %s)" % (left_text, right_text)

    def visit_try_cast(self, cast, **kwargs):
        """Render ``TRY_CAST(<clause> AS <type>)``."""
        clause_text = cast.clause._compiler_dispatch(self, **kwargs)
        type_text = cast.typeclause._compiler_dispatch(self, **kwargs)
        return "TRY_CAST(%s AS %s)" % (clause_text, type_text)
6824
6825
6826class DDLCompiler(Compiled):
6827 is_ddl = True
6828
    if TYPE_CHECKING:
        # typing-only constructor signature: this stub is defined under
        # TYPE_CHECKING and has no body, so it only informs type checkers
        # about the arguments accepted when constructing a DDLCompiler.

        def __init__(
            self,
            dialect: Dialect,
            statement: ExecutableDDLElement,
            schema_translate_map: Optional[SchemaTranslateMapType] = ...,
            render_schema_translate: bool = ...,
            compile_kwargs: Mapping[str, Any] = ...,
        ): ...
6839
6840 @util.ro_memoized_property
6841 def sql_compiler(self) -> SQLCompiler:
6842 return self.dialect.statement_compiler(
6843 self.dialect, None, schema_translate_map=self.schema_translate_map
6844 )
6845
6846 @util.memoized_property
6847 def type_compiler(self):
6848 return self.dialect.type_compiler_instance
6849
6850 def construct_params(
6851 self,
6852 params: Optional[_CoreSingleExecuteParams] = None,
6853 extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
6854 escape_names: bool = True,
6855 ) -> Optional[_MutableCoreSingleExecuteParams]:
6856 return None
6857
6858 def visit_ddl(self, ddl, **kwargs):
6859 # table events can substitute table and schema name
6860 context = ddl.context
6861 if isinstance(ddl.target, schema.Table):
6862 context = context.copy()
6863
6864 preparer = self.preparer
6865 path = preparer.format_table_seq(ddl.target)
6866 if len(path) == 1:
6867 table, sch = path[0], ""
6868 else:
6869 table, sch = path[-1], path[0]
6870
6871 context.setdefault("table", table)
6872 context.setdefault("schema", sch)
6873 context.setdefault("fullname", preparer.format_table(ddl.target))
6874
6875 return self.sql_compiler.post_process_text(ddl.statement % context)
6876
6877 def visit_create_schema(self, create, **kw):
6878 text = "CREATE SCHEMA "
6879 if create.if_not_exists:
6880 text += "IF NOT EXISTS "
6881 return text + self.preparer.format_schema(create.element)
6882
6883 def visit_drop_schema(self, drop, **kw):
6884 text = "DROP SCHEMA "
6885 if drop.if_exists:
6886 text += "IF EXISTS "
6887 text += self.preparer.format_schema(drop.element)
6888 if drop.cascade:
6889 text += " CASCADE"
6890 return text
6891
6892 def visit_create_table(self, create, **kw):
6893 table = create.element
6894 preparer = self.preparer
6895
6896 text = "\nCREATE "
6897 if table._prefixes:
6898 text += " ".join(table._prefixes) + " "
6899
6900 text += "TABLE "
6901 if create.if_not_exists:
6902 text += "IF NOT EXISTS "
6903
6904 text += preparer.format_table(table) + " "
6905
6906 create_table_suffix = self.create_table_suffix(table)
6907 if create_table_suffix:
6908 text += create_table_suffix + " "
6909
6910 text += "("
6911
6912 separator = "\n"
6913
6914 # if only one primary key, specify it along with the column
6915 first_pk = False
6916 for create_column in create.columns:
6917 column = create_column.element
6918 try:
6919 processed = self.process(
6920 create_column, first_pk=column.primary_key and not first_pk
6921 )
6922 if processed is not None:
6923 text += separator
6924 separator = ", \n"
6925 text += "\t" + processed
6926 if column.primary_key:
6927 first_pk = True
6928 except exc.CompileError as ce:
6929 raise exc.CompileError(
6930 "(in table '%s', column '%s'): %s"
6931 % (table.description, column.name, ce.args[0])
6932 ) from ce
6933
6934 const = self.create_table_constraints(
6935 table,
6936 _include_foreign_key_constraints=create.include_foreign_key_constraints, # noqa
6937 )
6938 if const:
6939 text += separator + "\t" + const
6940
6941 text += "\n)%s\n\n" % self.post_create_table(table)
6942 return text
6943
6944 def visit_create_table_as(self, element: CreateTableAs, **kw: Any) -> str:
6945 prep = self.preparer
6946
6947 inner_kw = dict(kw)
6948 inner_kw["literal_binds"] = True
6949 select_sql = self.sql_compiler.process(element.selectable, **inner_kw)
6950
6951 parts = [
6952 "CREATE",
6953 "TEMPORARY" if element.temporary else None,
6954 "TABLE",
6955 "IF NOT EXISTS" if element.if_not_exists else None,
6956 prep.format_table(element.table),
6957 "AS",
6958 select_sql,
6959 ]
6960 return " ".join(p for p in parts if p)
6961
6962 def visit_create_column(self, create, first_pk=False, **kw):
6963 column = create.element
6964
6965 if column.system:
6966 return None
6967
6968 text = self.get_column_specification(column, first_pk=first_pk)
6969 const = " ".join(
6970 self.process(constraint) for constraint in column.constraints
6971 )
6972 if const:
6973 text += " " + const
6974
6975 return text
6976
6977 def create_table_constraints(
6978 self, table, _include_foreign_key_constraints=None, **kw
6979 ):
6980 # On some DB order is significant: visit PK first, then the
6981 # other constraints (engine.ReflectionTest.testbasic failed on FB2)
6982 constraints = []
6983 if table.primary_key:
6984 constraints.append(table.primary_key)
6985
6986 all_fkcs = table.foreign_key_constraints
6987 if _include_foreign_key_constraints is not None:
6988 omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints)
6989 else:
6990 omit_fkcs = set()
6991
6992 constraints.extend(
6993 [
6994 c
6995 for c in table._sorted_constraints
6996 if c is not table.primary_key and c not in omit_fkcs
6997 ]
6998 )
6999
7000 return ", \n\t".join(
7001 p
7002 for p in (
7003 self.process(constraint)
7004 for constraint in constraints
7005 if (constraint._should_create_for_compiler(self))
7006 and (
7007 not self.dialect.supports_alter
7008 or not getattr(constraint, "use_alter", False)
7009 )
7010 )
7011 if p is not None
7012 )
7013
7014 def visit_drop_table(self, drop, **kw):
7015 text = "\nDROP TABLE "
7016 if drop.if_exists:
7017 text += "IF EXISTS "
7018 return text + self.preparer.format_table(drop.element)
7019
7020 def visit_drop_view(self, drop, **kw):
7021 return "\nDROP VIEW " + self.preparer.format_table(drop.element)
7022
7023 def _verify_index_table(self, index: Index) -> None:
7024 if index.table is None:
7025 raise exc.CompileError(
7026 "Index '%s' is not associated with any table." % index.name
7027 )
7028
7029 def visit_create_index(
7030 self, create, include_schema=False, include_table_schema=True, **kw
7031 ):
7032 index = create.element
7033 self._verify_index_table(index)
7034 preparer = self.preparer
7035 text = "CREATE "
7036 if index.unique:
7037 text += "UNIQUE "
7038 if index.name is None:
7039 raise exc.CompileError(
7040 "CREATE INDEX requires that the index have a name"
7041 )
7042
7043 text += "INDEX "
7044 if create.if_not_exists:
7045 text += "IF NOT EXISTS "
7046
7047 text += "%s ON %s (%s)" % (
7048 self._prepared_index_name(index, include_schema=include_schema),
7049 preparer.format_table(
7050 index.table, use_schema=include_table_schema
7051 ),
7052 ", ".join(
7053 self.sql_compiler.process(
7054 expr, include_table=False, literal_binds=True
7055 )
7056 for expr in index.expressions
7057 ),
7058 )
7059 return text
7060
7061 def visit_drop_index(self, drop, **kw):
7062 index = drop.element
7063
7064 if index.name is None:
7065 raise exc.CompileError(
7066 "DROP INDEX requires that the index have a name"
7067 )
7068 text = "\nDROP INDEX "
7069 if drop.if_exists:
7070 text += "IF EXISTS "
7071
7072 return text + self._prepared_index_name(index, include_schema=True)
7073
7074 def _prepared_index_name(
7075 self, index: Index, include_schema: bool = False
7076 ) -> str:
7077 if index.table is not None:
7078 effective_schema = self.preparer.schema_for_object(index.table)
7079 else:
7080 effective_schema = None
7081 if include_schema and effective_schema:
7082 schema_name = self.preparer.quote_schema(effective_schema)
7083 else:
7084 schema_name = None
7085
7086 index_name: str = self.preparer.format_index(index)
7087
7088 if schema_name:
7089 index_name = schema_name + "." + index_name
7090 return index_name
7091
7092 def visit_add_constraint(self, create, **kw):
7093 return "ALTER TABLE %s ADD %s" % (
7094 self.preparer.format_table(create.element.table),
7095 self.process(create.element),
7096 )
7097
7098 def visit_set_table_comment(self, create, **kw):
7099 return "COMMENT ON TABLE %s IS %s" % (
7100 self.preparer.format_table(create.element),
7101 self.sql_compiler.render_literal_value(
7102 create.element.comment, sqltypes.String()
7103 ),
7104 )
7105
7106 def visit_drop_table_comment(self, drop, **kw):
7107 return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
7108 drop.element
7109 )
7110
7111 def visit_set_column_comment(self, create, **kw):
7112 return "COMMENT ON COLUMN %s IS %s" % (
7113 self.preparer.format_column(
7114 create.element, use_table=True, use_schema=True
7115 ),
7116 self.sql_compiler.render_literal_value(
7117 create.element.comment, sqltypes.String()
7118 ),
7119 )
7120
7121 def visit_drop_column_comment(self, drop, **kw):
7122 return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column(
7123 drop.element, use_table=True
7124 )
7125
7126 def visit_set_constraint_comment(self, create, **kw):
7127 raise exc.UnsupportedCompilationError(self, type(create))
7128
7129 def visit_drop_constraint_comment(self, drop, **kw):
7130 raise exc.UnsupportedCompilationError(self, type(drop))
7131
7132 def get_identity_options(self, identity_options):
7133 text = []
7134 if identity_options.increment is not None:
7135 text.append("INCREMENT BY %d" % identity_options.increment)
7136 if identity_options.start is not None:
7137 text.append("START WITH %d" % identity_options.start)
7138 if identity_options.minvalue is not None:
7139 text.append("MINVALUE %d" % identity_options.minvalue)
7140 if identity_options.maxvalue is not None:
7141 text.append("MAXVALUE %d" % identity_options.maxvalue)
7142 if identity_options.nominvalue is not None:
7143 text.append("NO MINVALUE")
7144 if identity_options.nomaxvalue is not None:
7145 text.append("NO MAXVALUE")
7146 if identity_options.cache is not None:
7147 text.append("CACHE %d" % identity_options.cache)
7148 if identity_options.cycle is not None:
7149 text.append("CYCLE" if identity_options.cycle else "NO CYCLE")
7150 return " ".join(text)
7151
7152 def visit_create_sequence(self, create, prefix=None, **kw):
7153 text = "CREATE SEQUENCE "
7154 if create.if_not_exists:
7155 text += "IF NOT EXISTS "
7156 text += self.preparer.format_sequence(create.element)
7157
7158 if prefix:
7159 text += prefix
7160 options = self.get_identity_options(create.element)
7161 if options:
7162 text += " " + options
7163 return text
7164
7165 def visit_drop_sequence(self, drop, **kw):
7166 text = "DROP SEQUENCE "
7167 if drop.if_exists:
7168 text += "IF EXISTS "
7169 return text + self.preparer.format_sequence(drop.element)
7170
7171 def visit_drop_constraint(self, drop, **kw):
7172 constraint = drop.element
7173 if constraint.name is not None:
7174 formatted_name = self.preparer.format_constraint(constraint)
7175 else:
7176 formatted_name = None
7177
7178 if formatted_name is None:
7179 raise exc.CompileError(
7180 "Can't emit DROP CONSTRAINT for constraint %r; "
7181 "it has no name" % drop.element
7182 )
7183 return "ALTER TABLE %s DROP CONSTRAINT %s%s%s" % (
7184 self.preparer.format_table(drop.element.table),
7185 "IF EXISTS " if drop.if_exists else "",
7186 formatted_name,
7187 " CASCADE" if drop.cascade else "",
7188 )
7189
7190 def get_column_specification(self, column, **kwargs):
7191 colspec = (
7192 self.preparer.format_column(column)
7193 + " "
7194 + self.dialect.type_compiler_instance.process(
7195 column.type, type_expression=column
7196 )
7197 )
7198 default = self.get_column_default_string(column)
7199 if default is not None:
7200 colspec += " DEFAULT " + default
7201
7202 if column.computed is not None:
7203 colspec += " " + self.process(column.computed)
7204
7205 if (
7206 column.identity is not None
7207 and self.dialect.supports_identity_columns
7208 ):
7209 colspec += " " + self.process(column.identity)
7210
7211 if not column.nullable and (
7212 not column.identity or not self.dialect.supports_identity_columns
7213 ):
7214 colspec += " NOT NULL"
7215 return colspec
7216
7217 def create_table_suffix(self, table):
7218 return ""
7219
7220 def post_create_table(self, table):
7221 return ""
7222
7223 def get_column_default_string(self, column: Column[Any]) -> Optional[str]:
7224 if isinstance(column.server_default, schema.DefaultClause):
7225 return self.render_default_string(column.server_default.arg)
7226 else:
7227 return None
7228
7229 def render_default_string(self, default: Union[Visitable, str]) -> str:
7230 if isinstance(default, str):
7231 return self.sql_compiler.render_literal_value(
7232 default, sqltypes.STRINGTYPE
7233 )
7234 else:
7235 return self.sql_compiler.process(default, literal_binds=True)
7236
7237 def visit_table_or_column_check_constraint(self, constraint, **kw):
7238 if constraint.is_column_level:
7239 return self.visit_column_check_constraint(constraint)
7240 else:
7241 return self.visit_check_constraint(constraint)
7242
7243 def visit_check_constraint(self, constraint, **kw):
7244 text = ""
7245 if constraint.name is not None:
7246 formatted_name = self.preparer.format_constraint(constraint)
7247 if formatted_name is not None:
7248 text += "CONSTRAINT %s " % formatted_name
7249 text += "CHECK (%s)" % self.sql_compiler.process(
7250 constraint.sqltext, include_table=False, literal_binds=True
7251 )
7252 text += self.define_constraint_deferrability(constraint)
7253 return text
7254
7255 def visit_column_check_constraint(self, constraint, **kw):
7256 text = ""
7257 if constraint.name is not None:
7258 formatted_name = self.preparer.format_constraint(constraint)
7259 if formatted_name is not None:
7260 text += "CONSTRAINT %s " % formatted_name
7261 text += "CHECK (%s)" % self.sql_compiler.process(
7262 constraint.sqltext, include_table=False, literal_binds=True
7263 )
7264 text += self.define_constraint_deferrability(constraint)
7265 return text
7266
7267 def visit_primary_key_constraint(
7268 self, constraint: PrimaryKeyConstraint, **kw: Any
7269 ) -> str:
7270 if len(constraint) == 0:
7271 return ""
7272 text = ""
7273 if constraint.name is not None:
7274 formatted_name = self.preparer.format_constraint(constraint)
7275 if formatted_name is not None:
7276 text += "CONSTRAINT %s " % formatted_name
7277 text += "PRIMARY KEY "
7278 text += "(%s)" % ", ".join(
7279 self.preparer.quote(c.name)
7280 for c in (
7281 constraint.columns_autoinc_first
7282 if constraint._implicit_generated
7283 else constraint.columns
7284 )
7285 )
7286 text += self.define_constraint_deferrability(constraint)
7287 return text
7288
7289 def visit_foreign_key_constraint(self, constraint, **kw):
7290 preparer = self.preparer
7291 text = ""
7292 if constraint.name is not None:
7293 formatted_name = self.preparer.format_constraint(constraint)
7294 if formatted_name is not None:
7295 text += "CONSTRAINT %s " % formatted_name
7296 remote_table = list(constraint.elements)[0].column.table
7297 text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
7298 ", ".join(
7299 preparer.quote(f.parent.name) for f in constraint.elements
7300 ),
7301 self.define_constraint_remote_table(
7302 constraint, remote_table, preparer
7303 ),
7304 ", ".join(
7305 preparer.quote(f.column.name) for f in constraint.elements
7306 ),
7307 )
7308 text += self.define_constraint_match(constraint)
7309 text += self.define_constraint_cascades(constraint)
7310 text += self.define_constraint_deferrability(constraint)
7311 return text
7312
7313 def define_constraint_remote_table(self, constraint, table, preparer):
7314 """Format the remote table clause of a CREATE CONSTRAINT clause."""
7315
7316 return preparer.format_table(table)
7317
7318 def visit_unique_constraint(
7319 self, constraint: UniqueConstraint, **kw: Any
7320 ) -> str:
7321 if len(constraint) == 0:
7322 return ""
7323 text = ""
7324 if constraint.name is not None:
7325 formatted_name = self.preparer.format_constraint(constraint)
7326 if formatted_name is not None:
7327 text += "CONSTRAINT %s " % formatted_name
7328 text += "UNIQUE %s(%s)" % (
7329 self.define_unique_constraint_distinct(constraint, **kw),
7330 ", ".join(self.preparer.quote(c.name) for c in constraint),
7331 )
7332 text += self.define_constraint_deferrability(constraint)
7333 return text
7334
7335 def define_unique_constraint_distinct(
7336 self, constraint: UniqueConstraint, **kw: Any
7337 ) -> str:
7338 return ""
7339
7340 def define_constraint_cascades(
7341 self, constraint: ForeignKeyConstraint
7342 ) -> str:
7343 text = ""
7344 if constraint.ondelete is not None:
7345 text += self.define_constraint_ondelete_cascade(constraint)
7346
7347 if constraint.onupdate is not None:
7348 text += self.define_constraint_onupdate_cascade(constraint)
7349 return text
7350
7351 def define_constraint_ondelete_cascade(
7352 self, constraint: ForeignKeyConstraint
7353 ) -> str:
7354 return " ON DELETE %s" % self.preparer.validate_sql_phrase(
7355 constraint.ondelete, FK_ON_DELETE
7356 )
7357
7358 def define_constraint_onupdate_cascade(
7359 self, constraint: ForeignKeyConstraint
7360 ) -> str:
7361 return " ON UPDATE %s" % self.preparer.validate_sql_phrase(
7362 constraint.onupdate, FK_ON_UPDATE
7363 )
7364
7365 def define_constraint_deferrability(self, constraint: Constraint) -> str:
7366 text = ""
7367 if constraint.deferrable is not None:
7368 if constraint.deferrable:
7369 text += " DEFERRABLE"
7370 else:
7371 text += " NOT DEFERRABLE"
7372 if constraint.initially is not None:
7373 text += " INITIALLY %s" % self.preparer.validate_sql_phrase(
7374 constraint.initially, FK_INITIALLY
7375 )
7376 return text
7377
7378 def define_constraint_match(self, constraint):
7379 text = ""
7380 if constraint.match is not None:
7381 text += " MATCH %s" % constraint.match
7382 return text
7383
7384 def visit_computed_column(self, generated, **kw):
7385 text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
7386 generated.sqltext, include_table=False, literal_binds=True
7387 )
7388 if generated.persisted is True:
7389 text += " STORED"
7390 elif generated.persisted is False:
7391 text += " VIRTUAL"
7392 return text
7393
7394 def visit_identity_column(self, identity, **kw):
7395 text = "GENERATED %s AS IDENTITY" % (
7396 "ALWAYS" if identity.always else "BY DEFAULT",
7397 )
7398 options = self.get_identity_options(identity)
7399 if options:
7400 text += " (%s)" % options
7401 return text
7402
7403
class GenericTypeCompiler(TypeCompiler):
    """Render type specification strings using generic SQL type names.

    The ``visit_<UPPERCASE>`` methods each return the text for one
    concrete SQL type name.  The ``visit_<lowercase>`` methods handle the
    generic types and delegate to one of the uppercase methods.
    """

    # -- concrete (upper case) type names --------------------------------

    def visit_FLOAT(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
        return "FLOAT"

    def visit_DOUBLE(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
        return "DOUBLE"

    def visit_DOUBLE_PRECISION(
        self, type_: sqltypes.DOUBLE_PRECISION[Any], **kw: Any
    ) -> str:
        return "DOUBLE PRECISION"

    def visit_REAL(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
        return "REAL"

    def visit_NUMERIC(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
        """Render ``NUMERIC``, with precision and scale when they are set.

        The scale is rendered only when a precision is present as well.
        """
        precision = type_.precision
        if precision is None:
            return "NUMERIC"

        scale = type_.scale
        if scale is None:
            return "NUMERIC(%(precision)s)" % {"precision": precision}

        return "NUMERIC(%(precision)s, %(scale)s)" % {
            "precision": precision,
            "scale": scale,
        }

    def visit_DECIMAL(self, type_: sqltypes.DECIMAL[Any], **kw: Any) -> str:
        """Render ``DECIMAL``, with precision and scale when they are set.

        The scale is rendered only when a precision is present as well.
        """
        precision = type_.precision
        if precision is None:
            return "DECIMAL"

        scale = type_.scale
        if scale is None:
            return "DECIMAL(%(precision)s)" % {"precision": precision}

        return "DECIMAL(%(precision)s, %(scale)s)" % {
            "precision": precision,
            "scale": scale,
        }

    def visit_INTEGER(self, type_: sqltypes.Integer, **kw: Any) -> str:
        return "INTEGER"

    def visit_SMALLINT(self, type_: sqltypes.SmallInteger, **kw: Any) -> str:
        return "SMALLINT"

    def visit_BIGINT(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
        return "BIGINT"

    def visit_TIMESTAMP(self, type_: sqltypes.TIMESTAMP, **kw: Any) -> str:
        return "TIMESTAMP"

    def visit_DATETIME(self, type_: sqltypes.DateTime, **kw: Any) -> str:
        return "DATETIME"

    def visit_DATE(self, type_: sqltypes.Date, **kw: Any) -> str:
        return "DATE"

    def visit_TIME(self, type_: sqltypes.Time, **kw: Any) -> str:
        return "TIME"

    def visit_CLOB(self, type_: sqltypes.CLOB, **kw: Any) -> str:
        return "CLOB"

    def visit_NCLOB(self, type_: sqltypes.Text, **kw: Any) -> str:
        return "NCLOB"

    def _render_string_type(
        self, name: str, length: Optional[int], collation: Optional[str]
    ) -> str:
        """Render a string type name with optional length and collation.

        :param name: the SQL type name, e.g. ``"VARCHAR"``.
        :param length: rendered as ``(<length>)`` when truthy.
        :param collation: rendered as ``COLLATE "<collation>"`` when
         truthy.
        :return: the assembled type specification.
        """
        length_part = f"({length})" if length else ""
        collation_part = f' COLLATE "{collation}"' if collation else ""
        return name + length_part + collation_part

    def visit_CHAR(self, type_: sqltypes.CHAR, **kw: Any) -> str:
        return self._render_string_type("CHAR", type_.length, type_.collation)

    def visit_NCHAR(self, type_: sqltypes.NCHAR, **kw: Any) -> str:
        return self._render_string_type("NCHAR", type_.length, type_.collation)

    def visit_VARCHAR(self, type_: sqltypes.String, **kw: Any) -> str:
        return self._render_string_type(
            "VARCHAR", type_.length, type_.collation
        )

    def visit_NVARCHAR(self, type_: sqltypes.NVARCHAR, **kw: Any) -> str:
        return self._render_string_type(
            "NVARCHAR", type_.length, type_.collation
        )

    def visit_TEXT(self, type_: sqltypes.Text, **kw: Any) -> str:
        return self._render_string_type("TEXT", type_.length, type_.collation)

    def visit_UUID(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
        return "UUID"

    def visit_BLOB(self, type_: sqltypes.LargeBinary, **kw: Any) -> str:
        return "BLOB"

    def visit_BINARY(self, type_: sqltypes.BINARY, **kw: Any) -> str:
        """Render ``BINARY``, with ``(<length>)`` when a length is set."""
        length_part = "(%d)" % type_.length if type_.length else ""
        return "BINARY" + length_part

    def visit_VARBINARY(self, type_: sqltypes.VARBINARY, **kw: Any) -> str:
        """Render ``VARBINARY``, with ``(<length>)`` when a length is set."""
        length_part = "(%d)" % type_.length if type_.length else ""
        return "VARBINARY" + length_part

    def visit_BOOLEAN(self, type_: sqltypes.Boolean, **kw: Any) -> str:
        return "BOOLEAN"

    # -- generic (lower case) types, delegating to the names above -------

    def visit_uuid(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
        """Render a UUID type natively when possible, else as ``CHAR(32)``.

        The native rendering is used only when both the type's
        ``native_uuid`` flag and the dialect's ``supports_native_uuid``
        flag are true.
        """
        if type_.native_uuid and self.dialect.supports_native_uuid:
            return self.visit_UUID(type_, **kw)
        return self._render_string_type("CHAR", length=32, collation=None)

    def visit_large_binary(
        self, type_: sqltypes.LargeBinary, **kw: Any
    ) -> str:
        return self.visit_BLOB(type_, **kw)

    def visit_boolean(self, type_: sqltypes.Boolean, **kw: Any) -> str:
        return self.visit_BOOLEAN(type_, **kw)

    def visit_time(self, type_: sqltypes.Time, **kw: Any) -> str:
        return self.visit_TIME(type_, **kw)

    def visit_datetime(self, type_: sqltypes.DateTime, **kw: Any) -> str:
        return self.visit_DATETIME(type_, **kw)

    def visit_date(self, type_: sqltypes.Date, **kw: Any) -> str:
        return self.visit_DATE(type_, **kw)

    def visit_big_integer(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
        return self.visit_BIGINT(type_, **kw)

    def visit_small_integer(
        self, type_: sqltypes.SmallInteger, **kw: Any
    ) -> str:
        return self.visit_SMALLINT(type_, **kw)

    def visit_integer(self, type_: sqltypes.Integer, **kw: Any) -> str:
        return self.visit_INTEGER(type_, **kw)

    def visit_real(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
        return self.visit_REAL(type_, **kw)

    def visit_float(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
        return self.visit_FLOAT(type_, **kw)

    def visit_double(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
        return self.visit_DOUBLE(type_, **kw)

    def visit_numeric(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
        return self.visit_NUMERIC(type_, **kw)

    def visit_string(self, type_: sqltypes.String, **kw: Any) -> str:
        return self.visit_VARCHAR(type_, **kw)

    def visit_unicode(self, type_: sqltypes.Unicode, **kw: Any) -> str:
        return self.visit_VARCHAR(type_, **kw)

    def visit_text(self, type_: sqltypes.Text, **kw: Any) -> str:
        return self.visit_TEXT(type_, **kw)

    def visit_unicode_text(
        self, type_: sqltypes.UnicodeText, **kw: Any
    ) -> str:
        return self.visit_TEXT(type_, **kw)

    def visit_enum(self, type_: sqltypes.Enum, **kw: Any) -> str:
        return self.visit_VARCHAR(type_, **kw)

    def visit_null(self, type_, **kw):
        """Refuse to render the null type.

        :raises CompileError: always.
        """
        raise exc.CompileError(
            "Can't generate DDL for %r; "
            "did you forget to specify a "
            "type on this Column?" % type_
        )

    def visit_type_decorator(
        self, type_: TypeDecorator[Any], **kw: Any
    ) -> str:
        """Render the type that the decorator resolves to for the dialect."""
        resolved_type = type_.type_engine(self.dialect)
        return self.process(resolved_type, **kw)

    def visit_user_defined(
        self, type_: UserDefinedType[Any], **kw: Any
    ) -> str:
        """Render a user defined type through its ``get_col_spec()``."""
        return type_.get_col_spec(**kw)
7591
7592
class StrSQLTypeCompiler(GenericTypeCompiler):
    """Type compiler that produces a string for any type object.

    Types without a ``_compiler_dispatch`` attribute, and ``visit_*``
    methods that are not defined, are routed to ``_visit_unknown()``
    rather than failing.
    """

    def process(self, type_, **kw):
        """Dispatch ``type_`` to its visit method, tolerating non-types.

        An object that has no ``_compiler_dispatch`` attribute is
        rendered by ``_visit_unknown()``.
        """
        try:
            dispatch = type_._compiler_dispatch
        except AttributeError:
            return self._visit_unknown(type_, **kw)
        return dispatch(self, **kw)

    def __getattr__(self, key):
        """Resolve any missing ``visit_*`` attribute to ``_visit_unknown``.

        :raises AttributeError: for any other missing attribute name.
        """
        if not key.startswith("visit_"):
            raise AttributeError(key)
        return self._visit_unknown

    def _visit_unknown(self, type_, **kw):
        """Render a type for which no specific visit method is available.

        The class name is returned when it equals its own upper-cased
        form; otherwise the ``repr()`` of the object is returned.
        """
        class_name = type_.__class__.__name__
        if class_name == class_name.upper():
            return class_name
        return repr(type_)

    def visit_null(self, type_, **kw):
        """Render the null type as the string ``NULL``."""
        return "NULL"

    def visit_user_defined(self, type_, **kw):
        """Render a user defined type.

        ``get_col_spec()`` is used when the type provides it; otherwise
        the ``repr()`` of the type is returned.
        """
        try:
            col_spec_fn = type_.get_col_spec
        except AttributeError:
            return repr(type_)
        return col_spec_fn(**kw)
7624
7625
class _SchemaForObjectCallable(Protocol):
    """Typing protocol for a callable that accepts one positional-only
    object and is annotated as returning a string.

    Used in this module as the annotation of
    ``IdentifierPreparer.schema_for_object``.
    """

    def __call__(self, obj: Any, /) -> str: ...
7628
7629
class _BindNameForColProtocol(Protocol):
    """Typing protocol for a callable that accepts a ``ColumnClause`` as
    ``col`` and returns a string."""

    def __call__(self, col: ColumnClause[Any]) -> str: ...
7632
7633
class IdentifierPreparer:
    """Handle quoting and case-folding of identifiers based on options."""

    # collection of lower case words; an identifier whose lower-cased
    # form is found here is always quoted (see _requires_quotes)
    reserved_words = RESERVED_WORDS

    # object with a .match() method; an identifier that does not match
    # is quoted
    legal_characters = LEGAL_CHARACTERS

    # an identifier whose first character is found here is quoted
    illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS

    # string placed before a quoted identifier
    initial_quote: str

    # string placed after a quoted identifier
    final_quote: str

    # cache used by quote(): identifier -> rendered form
    _strings: MutableMapping[str, str]

    schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema")
    """Return the .schema attribute for an object.

    For the default IdentifierPreparer, the schema for an object is always
    the value of the ".schema" attribute.   if the preparer is replaced
    with one that has a non-empty schema_translate_map, the value of the
    ".schema" attribute is rendered a symbol that will be converted to a
    real schema name from the mapping post-compile.

    """

    # set by _with_schema_translate() to whether ``None`` was a key of the
    # schema translate map; checked again by _render_schema_translates()
    _includes_none_schema_translate: bool = False
7661
7662 def __init__(
7663 self,
7664 dialect: Dialect,
7665 initial_quote: str = '"',
7666 final_quote: Optional[str] = None,
7667 escape_quote: str = '"',
7668 quote_case_sensitive_collations: bool = True,
7669 omit_schema: bool = False,
7670 ):
7671 """Construct a new ``IdentifierPreparer`` object.
7672
7673 initial_quote
7674 Character that begins a delimited identifier.
7675
7676 final_quote
7677 Character that ends a delimited identifier. Defaults to
7678 `initial_quote`.
7679
7680 omit_schema
7681 Prevent prepending schema name. Useful for databases that do
7682 not support schemae.
7683 """
7684
7685 self.dialect = dialect
7686 self.initial_quote = initial_quote
7687 self.final_quote = final_quote or self.initial_quote
7688 self.escape_quote = escape_quote
7689 self.escape_to_quote = self.escape_quote * 2
7690 self.omit_schema = omit_schema
7691 self.quote_case_sensitive_collations = quote_case_sensitive_collations
7692 self._strings = {}
7693 self._double_percents = self.dialect.paramstyle in (
7694 "format",
7695 "pyformat",
7696 )
7697
7698 def _with_schema_translate(self, schema_translate_map):
7699 prep = self.__class__.__new__(self.__class__)
7700 prep.__dict__.update(self.__dict__)
7701
7702 includes_none = None in schema_translate_map
7703
7704 def symbol_getter(obj):
7705 name = obj.schema
7706 if obj._use_schema_map and (name is not None or includes_none):
7707 if name is not None and ("[" in name or "]" in name):
7708 raise exc.CompileError(
7709 "Square bracket characters ([]) not supported "
7710 "in schema translate name '%s'" % name
7711 )
7712 return quoted_name(
7713 "__[SCHEMA_%s]" % (name or "_none"), quote=False
7714 )
7715 else:
7716 return obj.schema
7717
7718 prep.schema_for_object = symbol_getter
7719 prep._includes_none_schema_translate = includes_none
7720 return prep
7721
7722 def _render_schema_translates(
7723 self, statement: str, schema_translate_map: SchemaTranslateMapType
7724 ) -> str:
7725 d = schema_translate_map
7726 if None in d:
7727 if not self._includes_none_schema_translate:
7728 raise exc.InvalidRequestError(
7729 "schema translate map which previously did not have "
7730 "`None` present as a key now has `None` present; compiled "
7731 "statement may lack adequate placeholders. Please use "
7732 "consistent keys in successive "
7733 "schema_translate_map dictionaries."
7734 )
7735
7736 d["_none"] = d[None] # type: ignore[index]
7737
7738 def replace(m):
7739 name = m.group(2)
7740 if name in d:
7741 effective_schema = d[name]
7742 else:
7743 if name in (None, "_none"):
7744 raise exc.InvalidRequestError(
7745 "schema translate map which previously had `None` "
7746 "present as a key now no longer has it present; don't "
7747 "know how to apply schema for compiled statement. "
7748 "Please use consistent keys in successive "
7749 "schema_translate_map dictionaries."
7750 )
7751 effective_schema = name
7752
7753 if not effective_schema:
7754 effective_schema = self.dialect.default_schema_name
7755 if not effective_schema:
7756 # TODO: no coverage here
7757 raise exc.CompileError(
7758 "Dialect has no default schema name; can't "
7759 "use None as dynamic schema target."
7760 )
7761 return self.quote_schema(effective_schema)
7762
7763 return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement)
7764
7765 def _escape_identifier(self, value: str) -> str:
7766 """Escape an identifier.
7767
7768 Subclasses should override this to provide database-dependent
7769 escaping behavior.
7770 """
7771
7772 value = value.replace(self.escape_quote, self.escape_to_quote)
7773 if self._double_percents:
7774 value = value.replace("%", "%%")
7775 return value
7776
7777 def _unescape_identifier(self, value: str) -> str:
7778 """Canonicalize an escaped identifier.
7779
7780 Subclasses should override this to provide database-dependent
7781 unescaping behavior that reverses _escape_identifier.
7782 """
7783
7784 return value.replace(self.escape_to_quote, self.escape_quote)
7785
7786 def validate_sql_phrase(self, element, reg):
7787 """keyword sequence filter.
7788
7789 a filter for elements that are intended to represent keyword sequences,
7790 such as "INITIALLY", "INITIALLY DEFERRED", etc. no special characters
7791 should be present.
7792
7793 """
7794
7795 if element is not None and not reg.match(element):
7796 raise exc.CompileError(
7797 "Unexpected SQL phrase: %r (matching against %r)"
7798 % (element, reg.pattern)
7799 )
7800 return element
7801
7802 def quote_identifier(self, value: str) -> str:
7803 """Quote an identifier.
7804
7805 Subclasses should override this to provide database-dependent
7806 quoting behavior.
7807 """
7808
7809 return (
7810 self.initial_quote
7811 + self._escape_identifier(value)
7812 + self.final_quote
7813 )
7814
7815 def _requires_quotes(self, value: str) -> bool:
7816 """Return True if the given identifier requires quoting."""
7817 lc_value = value.lower()
7818 return (
7819 lc_value in self.reserved_words
7820 or value[0] in self.illegal_initial_characters
7821 or not self.legal_characters.match(str(value))
7822 or (lc_value != value)
7823 )
7824
7825 def _requires_quotes_illegal_chars(self, value):
7826 """Return True if the given identifier requires quoting, but
7827 not taking case convention into account."""
7828 return not self.legal_characters.match(str(value))
7829
7830 def quote_schema(self, schema: str) -> str:
7831 """Conditionally quote a schema name.
7832
7833
7834 The name is quoted if it is a reserved word, contains quote-necessary
7835 characters, or is an instance of :class:`.quoted_name` which includes
7836 ``quote`` set to ``True``.
7837
7838 Subclasses can override this to provide database-dependent
7839 quoting behavior for schema names.
7840
7841 :param schema: string schema name
7842 """
7843 return self.quote(schema)
7844
7845 def quote(self, ident: str) -> str:
7846 """Conditionally quote an identifier.
7847
7848 The identifier is quoted if it is a reserved word, contains
7849 quote-necessary characters, or is an instance of
7850 :class:`.quoted_name` which includes ``quote`` set to ``True``.
7851
7852 Subclasses can override this to provide database-dependent
7853 quoting behavior for identifier names.
7854
7855 :param ident: string identifier
7856 """
7857 force = getattr(ident, "quote", None)
7858
7859 if force is None:
7860 if ident in self._strings:
7861 return self._strings[ident]
7862 else:
7863 if self._requires_quotes(ident):
7864 self._strings[ident] = self.quote_identifier(ident)
7865 else:
7866 self._strings[ident] = ident
7867 return self._strings[ident]
7868 elif force:
7869 return self.quote_identifier(ident)
7870 else:
7871 return ident
7872
7873 def format_collation(self, collation_name):
7874 if self.quote_case_sensitive_collations:
7875 return self.quote(collation_name)
7876 else:
7877 return collation_name
7878
7879 def format_sequence(
7880 self, sequence: schema.Sequence, use_schema: bool = True
7881 ) -> str:
7882 name = self.quote(sequence.name)
7883
7884 effective_schema = self.schema_for_object(sequence)
7885
7886 if (
7887 not self.omit_schema
7888 and use_schema
7889 and effective_schema is not None
7890 ):
7891 name = self.quote_schema(effective_schema) + "." + name
7892 return name
7893
7894 def format_label(
7895 self, label: Label[Any], name: Optional[str] = None
7896 ) -> str:
7897 return self.quote(name or label.name)
7898
7899 def format_alias(
7900 self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None
7901 ) -> str:
7902 if name is None:
7903 assert alias is not None
7904 return self.quote(alias.name)
7905 else:
7906 return self.quote(name)
7907
7908 def format_savepoint(self, savepoint, name=None):
7909 # Running the savepoint name through quoting is unnecessary
7910 # for all known dialects. This is here to support potential
7911 # third party use cases
7912 ident = name or savepoint.ident
7913 if self._requires_quotes(ident):
7914 ident = self.quote_identifier(ident)
7915 return ident
7916
7917 @util.preload_module("sqlalchemy.sql.naming")
7918 def format_constraint(
7919 self, constraint: Union[Constraint, Index], _alembic_quote: bool = True
7920 ) -> Optional[str]:
7921 naming = util.preloaded.sql_naming
7922
7923 if constraint.name is _NONE_NAME:
7924 name = naming._constraint_name_for_table(
7925 constraint, constraint.table
7926 )
7927
7928 if name is None:
7929 return None
7930 else:
7931 name = constraint.name
7932
7933 assert name is not None
7934 if constraint.__visit_name__ == "index":
7935 return self.truncate_and_render_index_name(
7936 name, _alembic_quote=_alembic_quote
7937 )
7938 else:
7939 return self.truncate_and_render_constraint_name(
7940 name, _alembic_quote=_alembic_quote
7941 )
7942
7943 def truncate_and_render_index_name(
7944 self, name: str, _alembic_quote: bool = True
7945 ) -> str:
7946 # calculate these at format time so that ad-hoc changes
7947 # to dialect.max_identifier_length etc. can be reflected
7948 # as IdentifierPreparer is long lived
7949 max_ = (
7950 self.dialect.max_index_name_length
7951 or self.dialect.max_identifier_length
7952 )
7953 return self._truncate_and_render_maxlen_name(
7954 name, max_, _alembic_quote
7955 )
7956
7957 def truncate_and_render_constraint_name(
7958 self, name: str, _alembic_quote: bool = True
7959 ) -> str:
7960 # calculate these at format time so that ad-hoc changes
7961 # to dialect.max_identifier_length etc. can be reflected
7962 # as IdentifierPreparer is long lived
7963 max_ = (
7964 self.dialect.max_constraint_name_length
7965 or self.dialect.max_identifier_length
7966 )
7967 return self._truncate_and_render_maxlen_name(
7968 name, max_, _alembic_quote
7969 )
7970
7971 def _truncate_and_render_maxlen_name(
7972 self, name: str, max_: int, _alembic_quote: bool
7973 ) -> str:
7974 if isinstance(name, elements._truncated_label):
7975 if len(name) > max_:
7976 name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
7977 else:
7978 self.dialect.validate_identifier(name)
7979
7980 if not _alembic_quote:
7981 return name
7982 else:
7983 return self.quote(name)
7984
7985 def format_index(self, index: Index) -> str:
7986 name = self.format_constraint(index)
7987 assert name is not None
7988 return name
7989
7990 def format_table(
7991 self,
7992 table: FromClause,
7993 use_schema: bool = True,
7994 name: Optional[str] = None,
7995 ) -> str:
7996 """Prepare a quoted table and schema name."""
7997 if name is None:
7998 if TYPE_CHECKING:
7999 assert isinstance(table, NamedFromClause)
8000 name = table.name
8001
8002 result = self.quote(name)
8003
8004 effective_schema = self.schema_for_object(table)
8005
8006 if not self.omit_schema and use_schema and effective_schema:
8007 result = self.quote_schema(effective_schema) + "." + result
8008 return result
8009
8010 def format_schema(self, name):
8011 """Prepare a quoted schema name."""
8012
8013 return self.quote(name)
8014
8015 def format_label_name(
8016 self,
8017 name,
8018 anon_map=None,
8019 ):
8020 """Prepare a quoted column name."""
8021
8022 if anon_map is not None and isinstance(
8023 name, elements._truncated_label
8024 ):
8025 name = name.apply_map(anon_map)
8026
8027 return self.quote(name)
8028
8029 def format_column(
8030 self,
8031 column: ColumnElement[Any],
8032 use_table: bool = False,
8033 name: Optional[str] = None,
8034 table_name: Optional[str] = None,
8035 use_schema: bool = False,
8036 anon_map: Optional[Mapping[str, Any]] = None,
8037 ) -> str:
8038 """Prepare a quoted column name."""
8039
8040 if name is None:
8041 name = column.name
8042 assert name is not None
8043
8044 if anon_map is not None and isinstance(
8045 name, elements._truncated_label
8046 ):
8047 name = name.apply_map(anon_map)
8048
8049 if not getattr(column, "is_literal", False):
8050 if use_table:
8051 return (
8052 self.format_table(
8053 column.table, use_schema=use_schema, name=table_name
8054 )
8055 + "."
8056 + self.quote(name)
8057 )
8058 else:
8059 return self.quote(name)
8060 else:
8061 # literal textual elements get stuck into ColumnClause a lot,
8062 # which shouldn't get quoted
8063
8064 if use_table:
8065 return (
8066 self.format_table(
8067 column.table, use_schema=use_schema, name=table_name
8068 )
8069 + "."
8070 + name
8071 )
8072 else:
8073 return name
8074
8075 def format_table_seq(self, table, use_schema=True):
8076 """Format table name and schema as a tuple."""
8077
8078 # Dialects with more levels in their fully qualified references
8079 # ('database', 'owner', etc.) could override this and return
8080 # a longer sequence.
8081
8082 effective_schema = self.schema_for_object(table)
8083
8084 if not self.omit_schema and use_schema and effective_schema:
8085 return (
8086 self.quote_schema(effective_schema),
8087 self.format_table(table, use_schema=False),
8088 )
8089 else:
8090 return (self.format_table(table, use_schema=False),)
8091
8092 @util.memoized_property
8093 def _r_identifiers(self):
8094 initial, final, escaped_final = (
8095 re.escape(s)
8096 for s in (
8097 self.initial_quote,
8098 self.final_quote,
8099 self._escape_identifier(self.final_quote),
8100 )
8101 )
8102 r = re.compile(
8103 r"(?:"
8104 r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s"
8105 r"|([^\.]+))(?=\.|$))+"
8106 % {"initial": initial, "final": final, "escaped": escaped_final}
8107 )
8108 return r
8109
8110 def unformat_identifiers(self, identifiers: str) -> Sequence[str]:
8111 """Unpack 'schema.table.column'-like strings into components."""
8112
8113 r = self._r_identifiers
8114 return [
8115 self._unescape_identifier(i)
8116 for i in [a or b for a, b in r.findall(identifiers)]
8117 ]