1# sql/compiler.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Base SQL and DDL compiler implementations.
10
11Classes provided include:
12
13:class:`.compiler.SQLCompiler` - renders SQL
14strings
15
16:class:`.compiler.DDLCompiler` - renders DDL
17(data definition language) strings
18
19:class:`.compiler.GenericTypeCompiler` - renders
20type specification strings.
21
22To generate user-defined SQL strings, see
23:doc:`/ext/compiler`.
24
25"""
26from __future__ import annotations
27
28import collections
29import collections.abc as collections_abc
30import contextlib
31from enum import IntEnum
32import functools
33import itertools
34import operator
35import re
36from time import perf_counter
37import typing
38from typing import Any
39from typing import Callable
40from typing import cast
41from typing import ClassVar
42from typing import Dict
43from typing import FrozenSet
44from typing import Iterable
45from typing import Iterator
46from typing import List
47from typing import Mapping
48from typing import MutableMapping
49from typing import NamedTuple
50from typing import NoReturn
51from typing import Optional
52from typing import Pattern
53from typing import Sequence
54from typing import Set
55from typing import Tuple
56from typing import Type
57from typing import TYPE_CHECKING
58from typing import Union
59
60from . import base
61from . import coercions
62from . import crud
63from . import elements
64from . import functions
65from . import operators
66from . import roles
67from . import schema
68from . import selectable
69from . import sqltypes
70from . import util as sql_util
71from ._typing import is_column_element
72from ._typing import is_dml
73from .base import _de_clone
74from .base import _from_objects
75from .base import _NONE_NAME
76from .base import _SentinelDefaultCharacterization
77from .base import NO_ARG
78from .elements import quoted_name
79from .sqltypes import TupleType
80from .visitors import prefix_anon_map
81from .. import exc
82from .. import util
83from ..util import FastIntFlag
84from ..util.typing import Literal
85from ..util.typing import Protocol
86from ..util.typing import Self
87from ..util.typing import TypedDict
88
89if typing.TYPE_CHECKING:
90 from .annotation import _AnnotationDict
91 from .base import _AmbiguousTableNameMap
92 from .base import CompileState
93 from .base import Executable
94 from .cache_key import CacheKey
95 from .ddl import ExecutableDDLElement
96 from .dml import Insert
97 from .dml import Update
98 from .dml import UpdateBase
99 from .dml import UpdateDMLState
100 from .dml import ValuesBase
101 from .elements import _truncated_label
102 from .elements import BinaryExpression
103 from .elements import BindParameter
104 from .elements import ClauseElement
105 from .elements import ColumnClause
106 from .elements import ColumnElement
107 from .elements import False_
108 from .elements import Label
109 from .elements import Null
110 from .elements import True_
111 from .functions import Function
112 from .schema import Column
113 from .schema import Constraint
114 from .schema import ForeignKeyConstraint
115 from .schema import Index
116 from .schema import PrimaryKeyConstraint
117 from .schema import Table
118 from .schema import UniqueConstraint
119 from .selectable import _ColumnsClauseElement
120 from .selectable import AliasedReturnsRows
121 from .selectable import CompoundSelectState
122 from .selectable import CTE
123 from .selectable import FromClause
124 from .selectable import NamedFromClause
125 from .selectable import ReturnsRows
126 from .selectable import Select
127 from .selectable import SelectState
128 from .type_api import _BindProcessorType
129 from .type_api import TypeDecorator
130 from .type_api import TypeEngine
131 from .type_api import UserDefinedType
132 from .visitors import Visitable
133 from ..engine.cursor import CursorResultMetaData
134 from ..engine.interfaces import _CoreSingleExecuteParams
135 from ..engine.interfaces import _DBAPIAnyExecuteParams
136 from ..engine.interfaces import _DBAPIMultiExecuteParams
137 from ..engine.interfaces import _DBAPISingleExecuteParams
138 from ..engine.interfaces import _ExecuteOptions
139 from ..engine.interfaces import _GenericSetInputSizesType
140 from ..engine.interfaces import _MutableCoreSingleExecuteParams
141 from ..engine.interfaces import Dialect
142 from ..engine.interfaces import SchemaTranslateMapType
143
144
145_FromHintsType = Dict["FromClause", str]
146
147RESERVED_WORDS = {
148 "all",
149 "analyse",
150 "analyze",
151 "and",
152 "any",
153 "array",
154 "as",
155 "asc",
156 "asymmetric",
157 "authorization",
158 "between",
159 "binary",
160 "both",
161 "case",
162 "cast",
163 "check",
164 "collate",
165 "column",
166 "constraint",
167 "create",
168 "cross",
169 "current_date",
170 "current_role",
171 "current_time",
172 "current_timestamp",
173 "current_user",
174 "default",
175 "deferrable",
176 "desc",
177 "distinct",
178 "do",
179 "else",
180 "end",
181 "except",
182 "false",
183 "for",
184 "foreign",
185 "freeze",
186 "from",
187 "full",
188 "grant",
189 "group",
190 "having",
191 "ilike",
192 "in",
193 "initially",
194 "inner",
195 "intersect",
196 "into",
197 "is",
198 "isnull",
199 "join",
200 "leading",
201 "left",
202 "like",
203 "limit",
204 "localtime",
205 "localtimestamp",
206 "natural",
207 "new",
208 "not",
209 "notnull",
210 "null",
211 "off",
212 "offset",
213 "old",
214 "on",
215 "only",
216 "or",
217 "order",
218 "outer",
219 "overlaps",
220 "placing",
221 "primary",
222 "references",
223 "right",
224 "select",
225 "session_user",
226 "set",
227 "similar",
228 "some",
229 "symmetric",
230 "table",
231 "then",
232 "to",
233 "trailing",
234 "true",
235 "union",
236 "unique",
237 "user",
238 "using",
239 "verbose",
240 "when",
241 "where",
242}
243
# Identifier character classes (case-insensitive).  Identifiers made only
# of these characters presumably need no quoting.
LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", flags=re.IGNORECASE)
LEGAL_CHARACTERS_PLUS_SPACE = re.compile(
    r"^[A-Z0-9_ $]+$", flags=re.IGNORECASE
)
# characters an identifier may not start with: any digit, or "$"
ILLEGAL_INITIAL_CHARACTERS = set("0123456789") | {"$"}

# Accepted referential actions / deferral modes for foreign keys.
FK_ON_DELETE = re.compile(
    r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$",
    flags=re.IGNORECASE,
)
FK_ON_UPDATE = re.compile(
    r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$",
    flags=re.IGNORECASE,
)
FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", flags=re.IGNORECASE)

# ``:name`` placeholders in textual SQL.  The lookbehind skips "::" casts
# and backslash-escaped colons; BIND_PARAMS_ESC finds the escaped form.
BIND_PARAMS = re.compile(
    r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", flags=re.UNICODE
)
BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", flags=re.UNICODE)

# Placeholder templates per DBAPI paramstyle.  "%%" doubles the percent
# sign so the template survives one round of %-formatting.
_pyformat_template = "%%(%(name)s)s"
BIND_TEMPLATES = {
    "pyformat": _pyformat_template,
    "qmark": "?",
    "format": "%%s",
    "numeric": ":[_POSITION]",
    "numeric_dollar": "$[_POSITION]",
    "named": ":%(name)s",
}
267
268
# Default SQL text rendered for each operator, keyed by the operator
# callable from ``sql.operators``.  The padding inside each value is part
# of the rendering: binary operators carry a space on both sides, prefix
# operators (EXISTS, NOT, ANY, ...) carry a trailing space, and postfix
# modifiers (DESC, NULLS FIRST, ...) carry a leading space.  Where and how
# this table is consulted is decided by compiler code outside this view.
OPERATORS = {
    # binary
    operators.and_: " AND ",
    operators.or_: " OR ",
    operators.add: " + ",
    operators.mul: " * ",
    operators.sub: " - ",
    operators.mod: " % ",
    # NOTE(review): neg renders as a bare prefix "-" even though it sits
    # in the "binary" group
    operators.neg: "-",
    operators.lt: " < ",
    operators.le: " <= ",
    operators.ne: " != ",
    operators.gt: " > ",
    operators.ge: " >= ",
    operators.eq: " = ",
    operators.is_distinct_from: " IS DISTINCT FROM ",
    operators.is_not_distinct_from: " IS NOT DISTINCT FROM ",
    operators.concat_op: " || ",
    operators.match_op: " MATCH ",
    operators.not_match_op: " NOT MATCH ",
    operators.in_op: " IN ",
    operators.not_in_op: " NOT IN ",
    operators.comma_op: ", ",
    operators.from_: " FROM ",
    operators.as_: " AS ",
    operators.is_: " IS ",
    operators.is_not: " IS NOT ",
    operators.collate: " COLLATE ",
    # unary
    operators.exists: "EXISTS ",
    operators.distinct_op: "DISTINCT ",
    operators.inv: "NOT ",
    operators.any_op: "ANY ",
    operators.all_op: "ALL ",
    # modifiers
    operators.desc_op: " DESC",
    operators.asc_op: " ASC",
    operators.nulls_first_op: " NULLS FIRST",
    operators.nulls_last_op: " NULLS LAST",
    # bitwise
    operators.bitwise_xor_op: " ^ ",
    operators.bitwise_or_op: " | ",
    operators.bitwise_and_op: " & ",
    # prefix operator: no padding
    operators.bitwise_not_op: "~",
    operators.bitwise_lshift_op: " << ",
    operators.bitwise_rshift_op: " >> ",
}
316
# Fixed SQL names for a set of generic function classes from
# ``sql.functions``.  Several of these names are SQL keywords rather than
# ordinary function names (CURRENT_DATE, LOCALTIME, SESSION_USER,
# GROUPING SETS, ...).  Whether parentheses/arguments follow the name is
# decided by compiler code outside this view.
FUNCTIONS: Dict[Type[Function[Any]], str] = {
    functions.coalesce: "coalesce",
    functions.current_date: "CURRENT_DATE",
    functions.current_time: "CURRENT_TIME",
    functions.current_timestamp: "CURRENT_TIMESTAMP",
    functions.current_user: "CURRENT_USER",
    functions.localtime: "LOCALTIME",
    functions.localtimestamp: "LOCALTIMESTAMP",
    functions.random: "random",
    functions.sysdate: "sysdate",
    functions.session_user: "SESSION_USER",
    functions.user: "USER",
    functions.cube: "CUBE",
    functions.rollup: "ROLLUP",
    # note: two words; rendered as-is
    functions.grouping_sets: "GROUPING SETS",
}
333
334
# Field names accepted by EXTRACT(), mapped to the name to render.  The
# generic mapping is the identity; the dict form exists so that a dialect
# can substitute its own spelling for individual fields.
EXTRACT_MAP = {
    field: field
    for field in (
        "month",
        "day",
        "year",
        "second",
        "hour",
        "doy",
        "minute",
        "quarter",
        "dow",
        "week",
        "epoch",
        "milliseconds",
        "microseconds",
        "timezone_hour",
        "timezone_minute",
    )
}
352
# SQL keyword text used to join the members of a compound SELECT,
# keyed by the ``_CompoundSelectKeyword`` member that requested it.
COMPOUND_KEYWORDS = {
    selectable._CompoundSelectKeyword.UNION: "UNION",
    selectable._CompoundSelectKeyword.UNION_ALL: "UNION ALL",
    selectable._CompoundSelectKeyword.EXCEPT: "EXCEPT",
    selectable._CompoundSelectKeyword.EXCEPT_ALL: "EXCEPT ALL",
    selectable._CompoundSelectKeyword.INTERSECT: "INTERSECT",
    selectable._CompoundSelectKeyword.INTERSECT_ALL: "INTERSECT ALL",
}
361
362
class ResultColumnsEntry(NamedTuple):
    """Tracks a column expression that is expected to be represented
    in the result rows for this statement.

    This normally refers to the columns clause of a SELECT statement
    but may also refer to a RETURNING clause, as well as for dialect-specific
    emulations.

    """

    # NOTE: field order is part of the interface.  The module-level
    # RM_RENDERED_NAME / RM_NAME / RM_OBJECTS / RM_TYPE constants index
    # these entries positionally (0..3), so do not reorder the fields.

    keyname: str
    """string name that's expected in cursor.description"""

    name: str
    """column name, may be labeled"""

    objects: Tuple[Any, ...]
    """sequence of objects that should be able to locate this column
    in a RowMapping. This is typically string names and aliases
    as well as Column objects.

    """

    type: TypeEngine[Any]
    """Datatype to be associated with this column. This is where
    the "result processing" logic directly links the compiled statement
    to the rows that come back from the cursor.

    """
392
393
class _ResultMapAppender(Protocol):
    """Structural type for a callable that records one result column.

    Its parameters line up with the fields of :class:`.ResultColumnsEntry`
    (``keyname``, ``name``, ``objects``, ``type``); the callable returns
    nothing.

    """

    def __call__(
        self,
        keyname: str,
        name: str,
        objects: Sequence[Any],
        type_: TypeEngine[Any],
    ) -> None: ...
402
403
# integer indexes into ResultColumnsEntry used by cursor.py.
# some profiling showed integer access faster than named tuple
# attribute access.  Each value must equal the position of the matching
# ResultColumnsEntry field: keyname (0), name (1), objects (2), type (3).
RM_RENDERED_NAME: Literal[0] = 0
RM_NAME: Literal[1] = 1
RM_OBJECTS: Literal[2] = 2
RM_TYPE: Literal[3] = 3
410
411
class _BaseCompilerStackEntry(TypedDict):
    """Keys that every compiler stack entry must carry (``total=True``).

    Presumably one entry is pushed per nested SELECT / DML level while
    compiling; the stack itself is maintained outside this view.

    """

    # FROM elements in play for this level; exact semantics of "asfrom"
    # vs. "correlate" sets are defined by the compiler code that fills them
    asfrom_froms: Set[FromClause]
    correlate_froms: Set[FromClause]
    # the statement / selectable this entry was pushed for
    selectable: ReturnsRows
416
417
class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
    """Compiler stack entry with optional extra keys.

    Because of ``total=False`` every key declared here may be absent; the
    keys inherited from :class:`._BaseCompilerStackEntry` remain required.

    """

    compile_state: CompileState
    need_result_map_for_nested: bool
    need_result_map_for_compound: bool
    select_0: ReturnsRows
    insert_from_select: Select[Any]
424
425
class ExpandedState(NamedTuple):
    """State used to produce "expanded" and "post compile" bound
    parameters for a statement.

    "Expanded" parameters are generated at execution time to fit however
    many values were passed, the main example being the individual
    elements inside an IN expression.

    "Post compile" parameters are those whose SQL literal value is written
    into the SQL string at execution time instead of being sent to the
    driver as separate parameters.

    Instances are obtained from
    :meth:`.SQLCompiler.construct_expanded_state` on any
    :class:`.SQLCompiler` instance.

    """

    statement: str
    """SQL string with all parameters expanded"""

    parameters: _CoreSingleExecuteParams
    """Dictionary of parameters, fully expanded.

    With named parameters, the keys match the names in the statement
    exactly.  With positional parameters, use
    :attr:`.ExpandedState.positional_parameters` to get the tuple form.

    """

    processors: Mapping[str, _BindProcessorType[Any]]
    """bound value processors, keyed by parameter name"""

    positiontup: Optional[Sequence[str]]
    """ordered parameter names for positional paramstyles"""

    parameter_expansion: Mapping[str, List[str]]
    """Links each expanded parameter's original name to the list of
    names it was expanded into."""

    @property
    def positional_parameters(self) -> Tuple[Any, ...]:
        """Parameter values as a tuple, in positional order.

        Only available when the statement was compiled with a positional
        paramstyle; otherwise :class:`.InvalidRequestError` is raised.

        """
        order = self.positiontup
        if order is None:
            raise exc.InvalidRequestError(
                "statement does not use a positional paramstyle"
            )
        values = self.parameters
        return tuple([values[param_name] for param_name in order])

    @property
    def additional_parameters(self) -> _CoreSingleExecuteParams:
        """Alias of :attr:`.ExpandedState.parameters`."""
        same_params = self.parameters
        return same_params
485
486
487class _InsertManyValues(NamedTuple):
488 """represents state to use for executing an "insertmanyvalues" statement.
489
490 The primary consumers of this object are the
491 :meth:`.SQLCompiler._deliver_insertmanyvalues_batches` and
492 :meth:`.DefaultDialect._deliver_insertmanyvalues_batches` methods.
493
494 .. versionadded:: 2.0
495
496 """
497
498 is_default_expr: bool
499 """if True, the statement is of the form
500 ``INSERT INTO TABLE DEFAULT VALUES``, and can't be rewritten as a "batch"
501
502 """
503
504 single_values_expr: str
505 """The rendered "values" clause of the INSERT statement.
506
507 This is typically the parenthesized section e.g. "(?, ?, ?)" or similar.
508 The insertmanyvalues logic uses this string as a search and replace
509 target.
510
511 """
512
513 insert_crud_params: List[crud._CrudParamElementStr]
514 """List of Column / bind names etc. used while rewriting the statement"""
515
516 num_positional_params_counted: int
517 """the number of bound parameters in a single-row statement.
518
519 This count may be larger or smaller than the actual number of columns
520 targeted in the INSERT, as it accommodates for SQL expressions
521 in the values list that may have zero or more parameters embedded
522 within them.
523
524 This count is part of what's used to organize rewritten parameter lists
525 when batching.
526
527 """
528
529 sort_by_parameter_order: bool = False
530 """if the deterministic_returnined_order parameter were used on the
531 insert.
532
533 All of the attributes following this will only be used if this is True.
534
535 """
536
537 includes_upsert_behaviors: bool = False
538 """if True, we have to accommodate for upsert behaviors.
539
540 This will in some cases downgrade "insertmanyvalues" that requests
541 deterministic ordering.
542
543 """
544
545 sentinel_columns: Optional[Sequence[Column[Any]]] = None
546 """List of sentinel columns that were located.
547
548 This list is only here if the INSERT asked for
549 sort_by_parameter_order=True,
550 and dialect-appropriate sentinel columns were located.
551
552 .. versionadded:: 2.0.10
553
554 """
555
556 num_sentinel_columns: int = 0
557 """how many sentinel columns are in the above list, if any.
558
559 This is the same as
560 ``len(sentinel_columns) if sentinel_columns is not None else 0``
561
562 """
563
564 sentinel_param_keys: Optional[Sequence[str]] = None
565 """parameter str keys in each param dictionary / tuple
566 that would link to the client side "sentinel" values for that row, which
567 we can use to match up parameter sets to result rows.
568
569 This is only present if sentinel_columns is present and the INSERT
570 statement actually refers to client side values for these sentinel
571 columns.
572
573 .. versionadded:: 2.0.10
574
575 .. versionchanged:: 2.0.29 - the sequence is now string dictionary keys
576 only, used against the "compiled parameteters" collection before
577 the parameters were converted by bound parameter processors
578
579 """
580
581 implicit_sentinel: bool = False
582 """if True, we have exactly one sentinel column and it uses a server side
583 value, currently has to generate an incrementing integer value.
584
585 The dialect in question would have asserted that it supports receiving
586 these values back and sorting on that value as a means of guaranteeing
587 correlation with the incoming parameter list.
588
589 .. versionadded:: 2.0.10
590
591 """
592
593 embed_values_counter: bool = False
594 """Whether to embed an incrementing integer counter in each parameter
595 set within the VALUES clause as parameters are batched over.
596
597 This is only used for a specific INSERT..SELECT..VALUES..RETURNING syntax
598 where a subquery is used to produce value tuples. Current support
599 includes PostgreSQL, Microsoft SQL Server.
600
601 .. versionadded:: 2.0.10
602
603 """
604
605
class _InsertManyValuesBatch(NamedTuple):
    """represents an individual batch SQL statement for insertmanyvalues.

    This is passed through the
    :meth:`.SQLCompiler._deliver_insertmanyvalues_batches` and
    :meth:`.DefaultDialect._deliver_insertmanyvalues_batches` methods out
    to the :class:`.Connection` within the
    :meth:`.Connection._exec_insertmany_context` method.

    .. versionadded:: 2.0.10

    """

    # NOTE: positional NamedTuple; field order is part of the interface.
    # The per-field notes below are inferred from names/types only — the
    # producing code is outside this view; confirm before relying on them.

    # SQL string for this batch (presumably with the VALUES clause
    # rewritten to cover every row in the batch)
    replaced_statement: str
    # DBAPI parameters to send along with replaced_statement
    replaced_parameters: _DBAPIAnyExecuteParams
    # setinputsizes info for this batch, if any
    processed_setinputsizes: Optional[_GenericSetInputSizesType]
    # the original per-row parameter sets folded into this batch
    batch: Sequence[_DBAPISingleExecuteParams]
    # presumably the client-side sentinel values per row, used to match
    # returned rows to parameter sets
    sentinel_values: Sequence[Tuple[Any, ...]]
    current_batch_size: int
    batchnum: int
    total_batches: int
    rows_sorted: bool
    is_downgraded: bool
629
630
class InsertmanyvaluesSentinelOpts(FastIntFlag):
    """Bit flags naming the primary key default styles that are able to
    act as implicit sentinel columns.

    """

    # single-bit flags
    NOT_SUPPORTED = 1 << 0
    AUTOINCREMENT = 1 << 1
    IDENTITY = 1 << 2
    SEQUENCE = 1 << 3

    # combined masks over the flags above
    ANY_AUTOINCREMENT = AUTOINCREMENT | IDENTITY | SEQUENCE
    _SUPPORTED_OR_NOT = NOT_SUPPORTED | ANY_AUTOINCREMENT

    # option bits; note that bit 5 (value 32) is left unassigned
    USE_INSERT_FROM_SELECT = 1 << 4
    RENDER_SELECT_COL_CASTS = 1 << 6
647
648
class CompilerState(IntEnum):
    """Lifecycle stage of a :class:`.Compiled` object.

    ``Compiled.__init__`` sets COMPILING and then STRING_APPLIED when a
    statement is given, or NO_STATEMENT when it is not.

    """

    COMPILING = 0
    """statement is present, compilation phase in progress"""

    STRING_APPLIED = 1
    """statement is present, string form of the statement has been applied.

    Additional processors by subclasses may still be pending.

    """

    NO_STATEMENT = 2
    """compiler does not have a statement to compile, is used
    for method access"""
663
664
class Linting(IntEnum):
    """Preferences for the "SQL linting" feature.

    At present linting covers one check: flagging cartesian products
    in SQL statements.

    """

    NO_LINTING = 0
    "Turn all linting off."

    COLLECT_CARTESIAN_PRODUCTS = 1
    """Gather FROM / cartesian product data into 'self.from_linter'"""

    WARN_LINTING = 2
    "Emit warnings when a linter finds a problem"

    FROM_LINTING = COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING
    """Warn on cartesian products; the union of COLLECT_CARTESIAN_PRODUCTS
    and WARN_LINTING"""


# module-level aliases for each member, in declaration order
NO_LINTING = Linting.NO_LINTING
COLLECT_CARTESIAN_PRODUCTS = Linting.COLLECT_CARTESIAN_PRODUCTS
WARN_LINTING = Linting.WARN_LINTING
FROM_LINTING = Linting.FROM_LINTING
691
692
class FromLinter(collections.namedtuple("FromLinter", "froms edges")):
    """State for "cartesian product" detection.

    ``froms`` maps each FROM element to its display text (used in the
    warning message); ``edges`` is a collection of two-element sequences,
    each joining a pair of FROM elements.

    """

    def lint(self, start=None):
        """Walk the join graph starting at ``start`` (or an arbitrary FROM).

        Returns ``(unreached_froms, start_element)`` when some FROMs cannot
        be reached through ``edges``, otherwise ``(None, None)``.  An empty
        ``froms`` also yields ``(None, None)``.

        """
        if not self.froms:
            return None, None

        remaining_edges = set(self.edges)
        unvisited = set(self.froms)

        if start is None:
            origin = unvisited.pop()
        else:
            origin = start
            unvisited.remove(origin)

        pending = collections.deque([origin])

        while pending and unvisited:
            current = pending.popleft()
            unvisited.discard(current)

            # edges are matched by containment / index() so that hash
            # equality applies; annotated elements compare equal to their
            # non-annotated counterparts this way, without calling hash()
            # from Python code.
            touching = {edge for edge in remaining_edges if current in edge}

            # push the *other* end of every touching edge onto the left
            pending.extendleft(
                edge[not edge.index(current)] for edge in touching
            )
            remaining_edges.difference_update(touching)

        # anything not reached is disconnected from the origin
        if unvisited:
            return unvisited, origin
        return None, None

    def warn(self, stmt_type="SELECT"):
        """Emit a warning if :meth:`lint` finds disconnected FROMs."""
        unreached, origin = self.lint()
        if not unreached:
            return

        rendered = ", ".join(
            f'"{self.froms[element]}"' for element in unreached
        )
        util.warn(
            (
                "{stmt_type} statement has a cartesian product between "
                "FROM element(s) {froms} and "
                'FROM element "{start}". Apply join condition(s) '
                "between each element to resolve."
            ).format(
                stmt_type=stmt_type,
                froms=rendered,
                start=self.froms[origin],
            )
        )
757
758
class Compiled:
    """Represent a compiled SQL or DDL expression.

    The ``__str__`` method of the ``Compiled`` object should produce
    the actual text of the statement. ``Compiled`` objects are
    specific to their underlying database dialect, and also may
    or may not be specific to the columns referenced within a
    particular set of bind parameters. In no case should the
    ``Compiled`` object be dependent on the actual values of those
    bind parameters, even though it may reference those values as
    defaults.

    Compilation happens during construction: ``__init__`` passes the
    statement to :meth:`.process` and stores the result in
    :attr:`.string`.  ``__str__`` returns that string only once
    :attr:`.state` is ``CompilerState.STRING_APPLIED``, and an empty
    string otherwise.
    """

    statement: Optional[ClauseElement] = None
    "The statement to compile."
    string: str = ""
    "The string representation of the ``statement``"

    state: CompilerState
    """description of the compiler's state"""

    # class-level defaults; subclasses override (SQLCompiler sets
    # is_sql = True)
    is_sql = False
    is_ddl = False

    _cached_metadata: Optional[CursorResultMetaData] = None

    _result_columns: Optional[List[ResultColumnsEntry]] = None

    schema_translate_map: Optional[SchemaTranslateMapType] = None

    execution_options: _ExecuteOptions = util.EMPTY_DICT
    """
    Execution options propagated from the statement. In some cases,
    sub-elements of the statement can modify these.
    """

    preparer: IdentifierPreparer

    _annotations: _AnnotationDict = util.EMPTY_DICT

    compile_state: Optional[CompileState] = None
    """Optional :class:`.CompileState` object that maintains additional
    state used by the compiler.

    Major executable objects such as :class:`_expression.Insert`,
    :class:`_expression.Update`, :class:`_expression.Delete`,
    :class:`_expression.Select` will generate this
    state when compiled in order to calculate additional information about the
    object. For the top level object that is to be executed, the state can be
    stored here where it can also have applicability towards result set
    processing.

    .. versionadded:: 1.4

    """

    dml_compile_state: Optional[CompileState] = None
    """Optional :class:`.CompileState` assigned at the same point that
    .isinsert, .isupdate, or .isdelete is assigned.

    This will normally be the same object as .compile_state, with the
    exception of cases like the :class:`.ORMFromStatementCompileState`
    object.

    .. versionadded:: 1.4.40

    """

    cache_key: Optional[CacheKey] = None
    """The :class:`.CacheKey` that was generated ahead of creating this
    :class:`.Compiled` object.

    This is used for routines that need access to the original
    :class:`.CacheKey` instance generated when the :class:`.Compiled`
    instance was first cached, typically in order to reconcile
    the original list of :class:`.BindParameter` objects with a
    per-statement list that's generated on each call.

    """

    _gen_time: float
    """Generation time of this :class:`.Compiled`, used for reporting
    cache stats."""

    def __init__(
        self,
        dialect: Dialect,
        statement: Optional[ClauseElement],
        schema_translate_map: Optional[SchemaTranslateMapType] = None,
        render_schema_translate: bool = False,
        compile_kwargs: Mapping[str, Any] = util.immutabledict(),
    ):
        """Construct a new :class:`.Compiled` object.

        :param dialect: :class:`.Dialect` to compile against.

        :param statement: :class:`_expression.ClauseElement` to be compiled.
         When ``None``, nothing is compiled and :attr:`.state` is set to
         ``CompilerState.NO_STATEMENT``.

        :param schema_translate_map: dictionary of schema names to be
         translated when forming the resultant SQL

         .. seealso::

            :ref:`schema_translating`

        :param render_schema_translate: if True, after compilation the
         string is additionally passed through the preparer's
         ``_render_schema_translates()`` together with
         ``schema_translate_map``, which must then not be ``None``
         (this is asserted).

        :param compile_kwargs: additional kwargs that will be
         passed to the initial call to :meth:`.Compiled.process`.


        """
        self.dialect = dialect
        self.preparer = self.dialect.identifier_preparer
        if schema_translate_map:
            # replace the dialect's preparer with one that carries the
            # schema translate map
            self.schema_translate_map = schema_translate_map
            self.preparer = self.preparer._with_schema_translate(
                schema_translate_map
            )

        if statement is not None:
            self.state = CompilerState.COMPILING
            self.statement = statement
            # NOTE(review): can_execute is only assigned on this branch and
            # has no class-level default in this class, yet
            # _execute_on_connection reads it; a NO_STATEMENT instance
            # would raise AttributeError there instead of
            # ObjectNotExecutableError.
            self.can_execute = statement.supports_execution
            self._annotations = statement._annotations
            if self.can_execute:
                if TYPE_CHECKING:
                    # static narrowing only; never executed at runtime
                    assert isinstance(statement, Executable)
                self.execution_options = statement._execution_options
            self.string = self.process(self.statement, **compile_kwargs)

            if render_schema_translate:
                assert schema_translate_map is not None
                self.string = self.preparer._render_schema_translates(
                    self.string, schema_translate_map
                )

            self.state = CompilerState.STRING_APPLIED
        else:
            self.state = CompilerState.NO_STATEMENT

        # taken after compilation completes (see _gen_time)
        self._gen_time = perf_counter()

    def __init_subclass__(cls) -> None:
        # run the _init_compiler_cls() hook once for every subclass, at
        # class-creation time
        cls._init_compiler_cls()
        return super().__init_subclass__()

    @classmethod
    def _init_compiler_cls(cls):
        """Per-subclass setup hook; a no-op on the base class."""
        pass

    def _execute_on_connection(
        self, connection, distilled_params, execution_options
    ):
        """Execute via ``connection._execute_compiled()`` when this
        compiled object is executable; otherwise raise
        :class:`.ObjectNotExecutableError` for the statement."""
        if self.can_execute:
            return connection._execute_compiled(
                self, distilled_params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self.statement)

    def visit_unsupported_compilation(self, element, err, **kw):
        """Raise :class:`.UnsupportedCompilationError` naming the type of
        ``element``, chained from ``err``."""
        raise exc.UnsupportedCompilationError(self, type(element)) from err

    @property
    def sql_compiler(self) -> SQLCompiler:
        """Return a Compiled that is capable of processing SQL expressions.

        If this compiler is one, it would likely just return 'self'.

        The base implementation raises ``NotImplementedError``.

        """

        raise NotImplementedError()

    def process(self, obj: Visitable, **kwargs: Any) -> str:
        """Compile ``obj`` by dispatching to its ``_compiler_dispatch()``
        with this compiler and ``kwargs``; returns the rendered string."""
        return obj._compiler_dispatch(self, **kwargs)

    def __str__(self) -> str:
        """Return the string text of the generated SQL or DDL."""

        if self.state is CompilerState.STRING_APPLIED:
            return self.string
        else:
            return ""

    def construct_params(
        self,
        params: Optional[_CoreSingleExecuteParams] = None,
        extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
        escape_names: bool = True,
    ) -> Optional[_MutableCoreSingleExecuteParams]:
        """Return the bind params for this compiled object.

        :param params: a dict of string/object pairs whose values will
         override bind values compiled in to the
         statement.

        :param extracted_parameters: presumably the per-execution list of
         :class:`.BindParameter` objects to reconcile against the ones
         captured when this object was cached (see :attr:`.cache_key`);
         TODO confirm against subclass implementations.

        :param escape_names: flag consumed by subclass implementations;
         its exact effect is defined there (not visible here).

        The base implementation raises ``NotImplementedError``.
        """

        raise NotImplementedError()

    @property
    def params(self):
        """Return the bind params for this compiled object.

        Equivalent to calling :meth:`.construct_params` with no arguments.
        """
        return self.construct_params()
961
962
class TypeCompiler(util.EnsureKWArg):
    """Renders DDL type specifications for :class:`.TypeEngine` objects.

    ``ensure_kwarg`` names the ``visit_*`` methods, presumably so that
    :class:`.util.EnsureKWArg` can guarantee they accept ``**kw``.
    """

    ensure_kwarg = r"visit_\w+"

    def __init__(self, dialect: Dialect):
        self.dialect = dialect

    def process(self, type_: TypeEngine[Any], **kw: Any) -> str:
        """Render ``type_``, first swapping in the variant registered for
        this compiler's dialect name, if there is one."""
        variants = type_._variant_mapping
        if variants:
            dialect_name = self.dialect.name
            if dialect_name in variants:
                type_ = variants[dialect_name]
        return type_._compiler_dispatch(self, **kw)

    def visit_unsupported_compilation(
        self, element: Any, err: Exception, **kw: Any
    ) -> NoReturn:
        """Raise :class:`.UnsupportedCompilationError` for ``element``,
        chained from ``err``."""
        raise exc.UnsupportedCompilationError(self, element) from err
983
984
# Historically this was a plain Visitable.  It is a column element so that
# checks for column elements recognize it.
class _CompileLabel(
    roles.BinaryElementRole[Any], elements.CompilerColumnElement
):
    """Lightweight stand-in for :class:`.expression.Label`, usable as an
    expression: it wraps ``element`` under the label ``name``."""

    __visit_name__ = "label"
    __slots__ = ("element", "name", "_alt_names")

    def __init__(self, col, name, alt_names=()):
        self.element = col
        self.name = name
        # the wrapped column is always the first alternate name
        self._alt_names = (col,) + alt_names

    @property
    def proxy_set(self):
        """Delegates to the wrapped element's ``proxy_set``."""
        wrapped = self.element
        return wrapped.proxy_set

    @property
    def type(self):
        """Delegates to the wrapped element's ``type``."""
        wrapped = self.element
        return wrapped.type

    def self_group(self, **kw):
        # returned unchanged; no grouping wrapper is applied
        return self
1010
1011
class ilike_case_insensitive(
    roles.BinaryElementRole[Any], elements.CompilerColumnElement
):
    """Wraps the case-insensitive operand of an ILIKE construct.

    Normally this renders the ``lower()`` function.  On PostgreSQL it is
    passed through untouched, on the assumption that "ILIKE" itself is
    being used.

    .. versionadded:: 2.0

    """

    __visit_name__ = "ilike_case_insensitive_operand"
    __slots__ = ("element", "comparator")

    def __init__(self, element):
        self.element = element
        # reuse the wrapped element's comparator
        self.comparator = element.comparator

    @property
    def proxy_set(self):
        """Delegates to the wrapped element's ``proxy_set``."""
        inner = self.element
        return inner.proxy_set

    @property
    def type(self):
        """Delegates to the wrapped element's ``type``."""
        inner = self.element
        return inner.type

    def self_group(self, **kw):
        # returned unchanged; no grouping wrapper is applied
        return self

    def _with_binary_element_type(self, type_):
        """Re-wrap a copy of the inner element retyped to ``type_``."""
        retyped = self.element._with_binary_element_type(type_)
        return ilike_case_insensitive(retyped)
1048
1049
1050class SQLCompiler(Compiled):
1051 """Default implementation of :class:`.Compiled`.
1052
1053 Compiles :class:`_expression.ClauseElement` objects into SQL strings.
1054
1055 """
1056
    # EXTRACT_MAP is defined elsewhere in this module; presumably it maps
    # extract() field names to the SQL rendered for them
    extract_map = EXTRACT_MAP

    bindname_escape_characters: ClassVar[Mapping[str, str]] = (
        util.immutabledict(
            {
                "%": "P",
                "(": "A",
                ")": "Z",
                ":": "C",
                ".": "_",
                "[": "_",
                "]": "_",
                " ": "_",
            }
        )
    )
    """A mapping (e.g. dict or similar) containing a lookup of
    characters keyed to replacement characters which will be applied to all
    'bind names' used in SQL statements as a form of 'escaping'; the given
    characters are replaced entirely with the 'replacement' character when
    rendered in the SQL statement, and a similar translation is performed
    on the incoming names used in parameter dictionaries passed to methods
    like :meth:`_engine.Connection.execute`.

    This allows bound parameter names used in :func:`_sql.bindparam` and
    other constructs to have any arbitrary characters present without any
    concern for characters that aren't allowed at all on the target database.

    Third party dialects can establish their own dictionary here to replace the
    default mapping, which will ensure that the particular characters in the
    mapping will never appear in a bound parameter name.

    The dictionary is evaluated at **class creation time**, so cannot be
    modified at runtime; it must be present on the class when the class
    is first declared.

    Note that for dialects that have additional bound parameter rules such
    as additional restrictions on leading characters, the
    :meth:`_sql.SQLCompiler.bindparam_string` method may need to be augmented.
    See the cx_Oracle compiler for an example of this.

    .. versionadded:: 2.0.0rc1

    """

    # both are assigned per class by _init_bind_translate() from
    # bindname_escape_characters: a regex matching any escapable character,
    # and the character -> replacement lookup itself
    _bind_translate_re: ClassVar[Pattern[str]]
    _bind_translate_chars: ClassVar[Mapping[str, str]]

    # NOTE(review): presumably marks this as a SQL (not DDL/type) compiler
    is_sql = True

    # COMPOUND_KEYWORDS is defined elsewhere in this module; presumably the
    # keywords rendered for compound SELECTs (UNION etc.) — confirm
    compound_keywords = COMPOUND_KEYWORDS
1108
    # statement-kind flags; __init__ reads them after Compiled.__init__ to
    # decide on "inline" rendering for INSERT/UPDATE
    isdelete: bool = False
    isinsert: bool = False
    isupdate: bool = False
    """class-level defaults which can be set at the instance
    level to define if this Compiled instance represents
    INSERT/UPDATE/DELETE
    """

    postfetch: Optional[List[Column[Any]]]
    """list of columns that can be post-fetched after INSERT or UPDATE to
    receive server-updated values"""

    # the ``prefetch`` property returns these two collections concatenated
    insert_prefetch: Sequence[Column[Any]] = ()
    """list of columns for which default values should be evaluated before
    an INSERT takes place"""

    update_prefetch: Sequence[Column[Any]] = ()
    """list of columns for which onupdate default values should be evaluated
    before an UPDATE takes place"""

    # when non-empty, effective_returning returns this collection directly
    implicit_returning: Optional[Sequence[ColumnElement[Any]]] = None
    """list of "implicit" returning columns for a toplevel INSERT or UPDATE
    statement, used to receive newly generated values of columns.

    .. versionadded:: 2.0 ``implicit_returning`` replaces the previous
        ``returning`` collection, which was not a generalized RETURNING
        collection and instead was in fact specific to the "implicit returning"
        feature.

    """

    isplaintext: bool = False

    # binds / bind_names / stack are initialized empty in __init__
    binds: Dict[str, BindParameter[Any]]
    """a dictionary of bind parameter keys to BindParameter instances."""

    bind_names: Dict[BindParameter[Any], str]
    """a dictionary of BindParameter instances to "compiled" names
    that are actually present in the generated SQL"""

    # each entry is a dict; current_executable reads entry["selectable"]
    stack: List[_CompilerStackEntry]
    """major statements such as SELECT, INSERT, UPDATE, DELETE are
    tracked in this stack using an entry format."""
1152
1153 returning_precedes_values: bool = False
1154 """set to True classwide to generate RETURNING
1155 clauses before the VALUES or WHERE clause (i.e. MSSQL)
1156 """
1157
1158 render_table_with_column_in_update_from: bool = False
1159 """set to True classwide to indicate the SET clause
1160 in a multi-table UPDATE statement should qualify
1161 columns with the table name (i.e. MySQL only)
1162 """
1163
1164 ansi_bind_rules: bool = False
1165 """SQL 92 doesn't allow bind parameters to be used
1166 in the columns clause of a SELECT, nor does it allow
1167 ambiguous expressions like "? = ?". A compiler
1168 subclass can set this flag to False if the target
1169 driver/DB enforces this
1170 """
1171
1172 bindtemplate: str
1173 """template to render bound parameters based on paramstyle."""
1174
1175 compilation_bindtemplate: str
1176 """template used by compiler to render parameters before positional
1177 paramstyle application"""
1178
1179 _numeric_binds_identifier_char: str
1180 """Character that's used to as the identifier of a numerical bind param.
1181 For example if this char is set to ``$``, numerical binds will be rendered
1182 in the form ``$1, $2, $3``.
1183 """
1184
    # ---- result-column bookkeeping (see also _nested_result) ----

    _result_columns: List[ResultColumnsEntry]
    """relates label names in the final SQL to a tuple of local
    column/label name, ColumnElement object (if any) and
    TypeEngine. CursorResult uses this for type processing and
    column targeting"""

    _textual_ordered_columns: bool = False
    """tell the result object that the column names as rendered are important,
    but they are also "ordered" vs. what is in the compiled object here.

    As of 1.4.42 this condition is only present when the statement is a
    TextualSelect, e.g. text("....").columns(...), where it is required
    that the columns are considered positionally and not by name.

    """

    _ad_hoc_textual: bool = False
    """tell the result that we encountered text() or '*' constructs in the
    middle of the result columns, but we also have compiled columns, so
    if the number of columns in cursor.description does not match how many
    expressions we have, that means we can't rely on positional at all and
    should match on name.

    """

    _ordered_columns: bool = True
    """
    if False, means we can't be sure the list of entries
    in _result_columns is actually the rendered order. Usually
    True unless using an unordered TextualSelect.
    """

    _loose_column_name_matching: bool = False
    """tell the result object that the SQL statement is textual, wants to match
    up to Column objects, and may be using the ._tq_label in the SELECT rather
    than the base name.

    """

    # ---- paramstyle / post-compile state ----

    # __init__ sets this for any paramstyle starting with "numeric",
    # which includes "numeric_dollar"
    _numeric_binds: bool = False
    """
    True if paramstyle is "numeric". This paramstyle is trickier than
    all the others.

    """

    _render_postcompile: bool = False
    """
    whether to render out POSTCOMPILE params during the compile phase.

    This attribute is used only for end-user invocation of stmt.compile();
    it's never used for actual statement execution, where instead the
    dialect internals access and render the internal postcompile structure
    directly.

    """

    _post_compile_expanded_state: Optional[ExpandedState] = None
    """When render_postcompile is used, the ``ExpandedState`` used to create
    the "expanded" SQL is assigned here, and then used by the ``.params``
    accessor and ``.construct_params()`` methods for their return values.

    .. versionadded:: 2.0.0rc1

    """

    _pre_expanded_string: Optional[str] = None
    """Stores the original string SQL before 'post_compile' is applied,
    for cases where 'post_compile' were used.

    """

    # when None, post-compile processing falls back to self.positiontup
    _pre_expanded_positiontup: Optional[List[str]] = None

    # when set, _process_positional() / _process_numeric() also rewrite its
    # single_values_expr and insert_crud_params placeholders
    _insertmanyvalues: Optional[_InsertManyValues] = None

    _insert_crud_params: Optional[crud._CrudParamSequence] = None
1262
1263 literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset()
1264 """bindparameter objects that are rendered as literal values at statement
1265 execution time.
1266
1267 """
1268
1269 post_compile_params: FrozenSet[BindParameter[Any]] = frozenset()
1270 """bindparameter objects that are rendered as bound parameter placeholders
1271 at statement execution time.
1272
1273 """
1274
1275 escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT
1276 """Late escaping of bound parameter names that has to be converted
1277 to the original name when looking in the parameter dictionary.
1278
1279 """
1280
1281 has_out_parameters = False
1282 """if True, there are bindparam() objects that have the isoutparam
1283 flag set."""
1284
1285 postfetch_lastrowid = False
1286 """if True, and this in insert, use cursor.lastrowid to populate
1287 result.inserted_primary_key. """
1288
1289 _cache_key_bind_match: Optional[
1290 Tuple[
1291 Dict[
1292 BindParameter[Any],
1293 List[BindParameter[Any]],
1294 ],
1295 Dict[
1296 str,
1297 BindParameter[Any],
1298 ],
1299 ]
1300 ] = None
1301 """a mapping that will relate the BindParameter object we compile
1302 to those that are part of the extracted collection of parameters
1303 in the cache key, if we were given a cache key.
1304
1305 """
1306
    # assigned by _process_positional() / _process_numeric()
    positiontup: Optional[List[str]] = None
    """for a compiled construct that uses a positional paramstyle, will be
    a sequence of strings, indicating the names of bound parameters in order.

    This is used in order to render bound parameters in their correct order,
    and is combined with the :attr:`_sql.Compiled.params` dictionary to
    render parameters.

    This sequence always contains the unescaped name of the parameters.

    .. seealso::

        :ref:`faq_sql_expression_string` - includes a usage example for
        debugging use cases.

    """
    # bind names belonging to the VALUES clause; with insertmanyvalues,
    # _process_numeric() numbers all other binds before these
    _values_bindparam: Optional[List[str]] = None

    _visited_bindparam: Optional[List[str]] = None

    inline: bool = False

    # ctes starts as None (__init__); it and the three attributes below are
    # created lazily by _init_cte_state()
    ctes: Optional[MutableMapping[CTE, str]]

    # Detect same CTE references - Dict[(level, name), cte]
    # Level is required for supporting nesting
    ctes_by_level_name: Dict[Tuple[int, str], CTE]

    # To retrieve key/level in ctes_by_level_name -
    # Dict[cte_reference, (level, cte_name, cte_opts)]
    level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]]

    ctes_recursive: bool

    # "__[POSTCOMPILE_<name>]" with an optional "~~...~~" suffix;
    # group 1 = name, group 2 = suffix
    _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]")
    # "%(<name>)s"; group 1 = name
    _pyformat_pattern = re.compile(r"%\(([^)]+?)\)s")
    # either of the above: group 1 = pyformat name, group 2 = post-compile
    # name, group 3 = post-compile suffix
    _positional_pattern = re.compile(
        f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}"
    )
1346
1347 @classmethod
1348 def _init_compiler_cls(cls):
1349 cls._init_bind_translate()
1350
1351 @classmethod
1352 def _init_bind_translate(cls):
1353 reg = re.escape("".join(cls.bindname_escape_characters))
1354 cls._bind_translate_re = re.compile(f"[{reg}]")
1355 cls._bind_translate_chars = cls.bindname_escape_characters
1356
    def __init__(
        self,
        dialect: Dialect,
        statement: Optional[ClauseElement],
        cache_key: Optional[CacheKey] = None,
        column_keys: Optional[Sequence[str]] = None,
        for_executemany: bool = False,
        linting: Linting = NO_LINTING,
        _supporting_against: Optional[SQLCompiler] = None,
        **kwargs: Any,
    ):
        """Construct a new :class:`.SQLCompiler` object.

        :param dialect: :class:`.Dialect` to be used

        :param statement: :class:`_expression.ClauseElement` to be compiled

        :param cache_key: optional cache key for ``statement``.  When given,
         the bound parameters in ``cache_key[1]`` are used to build
         ``_cache_key_bind_match``, which :meth:`.construct_params` consults
         when it is passed ``extracted_parameters``.

        :param column_keys: a list of column names to be compiled into an
         INSERT or UPDATE statement.

        :param for_executemany: whether INSERT / UPDATE statements should
         expect that they are to be invoked in an "executemany" style,
         which may impact how the statement will be expected to return the
         values of defaults and autoincrement / sequences and similar.
         Depending on the backend and driver in use, support for retrieving
         these values may be disabled which means SQL expressions may
         be rendered inline, RETURNING may not be rendered, etc.

        :param linting: linting flags; stored as ``self.linting``.

        :param _supporting_against: private; another :class:`.SQLCompiler`
         whose instance ``__dict__`` is copied onto this one once
         ``Compiled.__init__`` has run, except for the state, dialect,
         preparer and paramstyle-related attributes excluded below.

        :param kwargs: additional keyword arguments to be consumed by the
         superclass.

        """
        self.column_keys = column_keys

        self.cache_key = cache_key

        if cache_key:
            # cksm: bind key -> BindParameter; ckbm: BindParameter -> list of
            # BindParameters it stands for (initially just itself)
            cksm = {b.key: b for b in cache_key[1]}
            ckbm = {b: [b] for b in cache_key[1]}
            self._cache_key_bind_match = (ckbm, cksm)

        # compile INSERT/UPDATE defaults/sequences to expect executemany
        # style execution, which may mean no pre-execute of defaults,
        # or no RETURNING
        self.for_executemany = for_executemany

        self.linting = linting

        # a dictionary of bind parameter keys to BindParameter
        # instances.
        self.binds = {}

        # a dictionary of BindParameter instances to "compiled" names
        # that are actually present in the generated SQL
        self.bind_names = util.column_dict()

        # stack which keeps track of nested SELECT statements
        self.stack = []

        self._result_columns = []

        # true if the paramstyle is positional
        self.positional = dialect.positional
        if self.positional:
            # "numeric" and "numeric_dollar" both count as numeric binds;
            # only the identifier character differs
            self._numeric_binds = nb = dialect.paramstyle.startswith("numeric")
            if nb:
                self._numeric_binds_identifier_char = (
                    "$" if dialect.paramstyle == "numeric_dollar" else ":"
                )

            # positional styles are compiled with pyformat placeholders
            # first; _process_positional() / _process_numeric() below
            # rewrite them to the dialect's real style
            self.compilation_bindtemplate = _pyformat_template
        else:
            self.compilation_bindtemplate = BIND_TEMPLATES[dialect.paramstyle]

        # CTE collections are created lazily by _init_cte_state()
        self.ctes = None

        self.label_length = (
            dialect.label_length or dialect.max_identifier_length
        )

        # a map which tracks "anonymous" identifiers that are created on
        # the fly here
        self.anon_map = prefix_anon_map()

        # a map which tracks "truncated" names based on
        # dialect.label_length or dialect.max_identifier_length
        self.truncated_names: Dict[Tuple[str, str], str] = {}
        self._truncated_counters: Dict[str, int] = {}

        # NOTE(review): presumably this compiles ``statement`` into
        # self.string and sets self.state; the STRING_APPLIED check below
        # relies on it — confirm in Compiled.__init__
        Compiled.__init__(self, dialect, statement, **kwargs)

        if self.isinsert or self.isupdate or self.isdelete:
            if TYPE_CHECKING:
                assert isinstance(statement, UpdateBase)

            if self.isinsert or self.isupdate:
                if TYPE_CHECKING:
                    assert isinstance(statement, ValuesBase)
                # render inline when the statement requests it, or for an
                # executemany UPDATE, or for an executemany INSERT when the
                # dialect supports executemany RETURNING and return_defaults
                # was requested
                if statement._inline:
                    self.inline = True
                elif self.for_executemany and (
                    not self.isinsert
                    or (
                        self.dialect.insert_executemany_returning
                        and statement._return_defaults
                    )
                ):
                    self.inline = True

        # final template in the dialect's actual paramstyle
        self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]

        if _supporting_against:
            # adopt the other compiler's compiled state, keeping this
            # compiler's own state, dialect, preparer and paramstyle setup
            self.__dict__.update(
                {
                    k: v
                    for k, v in _supporting_against.__dict__.items()
                    if k
                    not in {
                        "state",
                        "dialect",
                        "preparer",
                        "positional",
                        "_numeric_binds",
                        "compilation_bindtemplate",
                        "bindtemplate",
                    }
                }
            )

        if self.state is CompilerState.STRING_APPLIED:
            # convert the pyformat placeholders to the positional style
            if self.positional:
                if self._numeric_binds:
                    self._process_numeric()
                else:
                    self._process_positional()

            # stmt.compile(compile_kwargs={"render_postcompile": True})-style
            # use: expand post-compile params now with the stored values
            if self._render_postcompile:
                parameters = self.construct_params(
                    escape_names=False,
                    _no_postcompile=True,
                )

                self._process_parameters_for_postcompile(
                    parameters, _populate_self=True
                )
1502
1503 @property
1504 def insert_single_values_expr(self) -> Optional[str]:
1505 """When an INSERT is compiled with a single set of parameters inside
1506 a VALUES expression, the string is assigned here, where it can be
1507 used for insert batching schemes to rewrite the VALUES expression.
1508
1509 .. versionadded:: 1.3.8
1510
1511 .. versionchanged:: 2.0 This collection is no longer used by
1512 SQLAlchemy's built-in dialects, in favor of the currently
1513 internal ``_insertmanyvalues`` collection that is used only by
1514 :class:`.SQLCompiler`.
1515
1516 """
1517 if self._insertmanyvalues is None:
1518 return None
1519 else:
1520 return self._insertmanyvalues.single_values_expr
1521
1522 @util.ro_memoized_property
1523 def effective_returning(self) -> Optional[Sequence[ColumnElement[Any]]]:
1524 """The effective "returning" columns for INSERT, UPDATE or DELETE.
1525
1526 This is either the so-called "implicit returning" columns which are
1527 calculated by the compiler on the fly, or those present based on what's
1528 present in ``self.statement._returning`` (expanded into individual
1529 columns using the ``._all_selected_columns`` attribute) i.e. those set
1530 explicitly using the :meth:`.UpdateBase.returning` method.
1531
1532 .. versionadded:: 2.0
1533
1534 """
1535 if self.implicit_returning:
1536 return self.implicit_returning
1537 elif self.statement is not None and is_dml(self.statement):
1538 return [
1539 c
1540 for c in self.statement._all_selected_columns
1541 if is_column_element(c)
1542 ]
1543
1544 else:
1545 return None
1546
1547 @property
1548 def returning(self):
1549 """backwards compatibility; returns the
1550 effective_returning collection.
1551
1552 """
1553 return self.effective_returning
1554
1555 @property
1556 def current_executable(self):
1557 """Return the current 'executable' that is being compiled.
1558
1559 This is currently the :class:`_sql.Select`, :class:`_sql.Insert`,
1560 :class:`_sql.Update`, :class:`_sql.Delete`,
1561 :class:`_sql.CompoundSelect` object that is being compiled.
1562 Specifically it's assigned to the ``self.stack`` list of elements.
1563
1564 When a statement like the above is being compiled, it normally
1565 is also assigned to the ``.statement`` attribute of the
1566 :class:`_sql.Compiler` object. However, all SQL constructs are
1567 ultimately nestable, and this attribute should never be consulted
1568 by a ``visit_`` method, as it is not guaranteed to be assigned
1569 nor guaranteed to correspond to the current statement being compiled.
1570
1571 .. versionadded:: 1.3.21
1572
1573 For compatibility with previous versions, use the following
1574 recipe::
1575
1576 statement = getattr(self, "current_executable", False)
1577 if statement is False:
1578 statement = self.stack[-1]["selectable"]
1579
1580 For versions 1.4 and above, ensure only .current_executable
1581 is used; the format of "self.stack" may change.
1582
1583
1584 """
1585 try:
1586 return self.stack[-1]["selectable"]
1587 except IndexError as ie:
1588 raise IndexError("Compiler does not have a stack entry") from ie
1589
1590 @property
1591 def prefetch(self):
1592 return list(self.insert_prefetch) + list(self.update_prefetch)
1593
1594 @util.memoized_property
1595 def _global_attributes(self) -> Dict[Any, Any]:
1596 return {}
1597
1598 @util.memoized_instancemethod
1599 def _init_cte_state(self) -> MutableMapping[CTE, str]:
1600 """Initialize collections related to CTEs only if
1601 a CTE is located, to save on the overhead of
1602 these collections otherwise.
1603
1604 """
1605 # collect CTEs to tack on top of a SELECT
1606 # To store the query to print - Dict[cte, text_query]
1607 ctes: MutableMapping[CTE, str] = util.OrderedDict()
1608 self.ctes = ctes
1609
1610 # Detect same CTE references - Dict[(level, name), cte]
1611 # Level is required for supporting nesting
1612 self.ctes_by_level_name = {}
1613
1614 # To retrieve key/level in ctes_by_level_name -
1615 # Dict[cte_reference, (level, cte_name, cte_opts)]
1616 self.level_name_by_cte = {}
1617
1618 self.ctes_recursive = False
1619
1620 return ctes
1621
1622 @contextlib.contextmanager
1623 def _nested_result(self):
1624 """special API to support the use case of 'nested result sets'"""
1625 result_columns, ordered_columns = (
1626 self._result_columns,
1627 self._ordered_columns,
1628 )
1629 self._result_columns, self._ordered_columns = [], False
1630
1631 try:
1632 if self.stack:
1633 entry = self.stack[-1]
1634 entry["need_result_map_for_nested"] = True
1635 else:
1636 entry = None
1637 yield self._result_columns, self._ordered_columns
1638 finally:
1639 if entry:
1640 entry.pop("need_result_map_for_nested")
1641 self._result_columns, self._ordered_columns = (
1642 result_columns,
1643 ordered_columns,
1644 )
1645
    def _process_positional(self):
        """Rewrite ``self.string`` from pyformat placeholders to the
        dialect's positional paramstyle and record the parameter order.

        Each ``%(name)s`` placeholder becomes ``%s`` ("format" paramstyle)
        or ``?`` ("qmark"), and its name is appended to
        ``self.positiontup`` in order of appearance.
        ``__[POSTCOMPILE_name]`` markers stay in the string, but their names
        are recorded in ``positiontup`` too.  Names in ``positiontup`` are
        the unescaped names, i.e. ``escaped_bind_names`` is reversed.

        When ``_insertmanyvalues`` is present, its ``single_values_expr``
        and the string at index 2 of each ``insert_crud_params`` entry are
        rewritten the same way.
        """
        assert not self.positiontup
        assert self.state is CompilerState.STRING_APPLIED
        assert not self._numeric_binds

        if self.dialect.paramstyle == "format":
            placeholder = "%s"
        else:
            assert self.dialect.paramstyle == "qmark"
            placeholder = "?"

        # bind names in order of appearance; appended to by find_position()
        # as a side effect of the re.sub() calls below
        positions = []

        def find_position(m: re.Match[str]) -> str:
            # group 1: name of a "%(name)s" placeholder;
            # group 2: name inside a "__[POSTCOMPILE_name]" marker
            normal_bind = m.group(1)
            if normal_bind:
                positions.append(normal_bind)
                return placeholder
            else:
                # this is a post-compile bind; keep the marker as-is
                positions.append(m.group(2))
                return m.group(0)

        self.string = re.sub(
            self._positional_pattern, find_position, self.string
        )

        if self.escaped_bind_names:
            # the string carries escaped names; map them back to the
            # original names.  The escaping must be one-to-one for this
            # reverse mapping to be valid.
            reverse_escape = {v: k for k, v in self.escaped_bind_names.items()}
            assert len(self.escaped_bind_names) == len(reverse_escape)
            self.positiontup = [
                reverse_escape.get(name, name) for name in positions
            ]
        else:
            self.positiontup = positions

        if self._insertmanyvalues:
            # rebind to a fresh list: find_position() closes over the name
            # ``positions``, and in the branch above self.positiontup may be
            # the very same list object.  Without this, the re.sub() calls
            # below would append to self.positiontup.  The names collected
            # from here on are discarded.
            positions = []

            single_values_expr = re.sub(
                self._positional_pattern,
                find_position,
                self._insertmanyvalues.single_values_expr,
            )
            insert_crud_params = [
                (
                    v[0],
                    v[1],
                    re.sub(self._positional_pattern, find_position, v[2]),
                    v[3],
                )
                for v in self._insertmanyvalues.insert_crud_params
            ]

            self._insertmanyvalues = self._insertmanyvalues._replace(
                single_values_expr=single_values_expr,
                insert_crud_params=insert_crud_params,
            )
1704
    def _process_numeric(self):
        """Rewrite ``self.string`` from pyformat placeholders to numbered
        placeholders such as ``:1`` or ``$1``.

        The prefix is ``_numeric_binds_identifier_char``.  Numbers follow
        ``bind_names`` order, except that when both ``_insertmanyvalues``
        and ``_values_bindparam`` are set, binds outside the VALUES clause
        are numbered first.  Binds in ``post_compile_params`` or
        ``literal_execute_params`` get no number but still appear in
        ``positiontup``.

        Sets ``self.positiontup`` (unescaped names, in numbering order) and
        ``self.next_numeric_pos`` (the next unused number).  When
        ``_insertmanyvalues`` is present, its ``single_values_expr`` receives
        the numbered placeholders and each ``insert_crud_params`` string is
        replaced by a bare ``%s``.
        """
        assert self._numeric_binds
        assert self.state is CompilerState.STRING_APPLIED

        num = 1
        # bind name -> rendered placeholder (None for post-compile /
        # literal-execute binds); insertion order becomes positiontup
        param_pos: Dict[str, str] = {}
        order: Iterable[str]
        if self._insertmanyvalues and self._values_bindparam is not None:
            # bindparams that are not in values are always placed first.
            # this avoids the need of changing them when using executemany
            # values () ()
            order = itertools.chain(
                (
                    name
                    for name in self.bind_names.values()
                    if name not in self._values_bindparam
                ),
                # every name again; names already numbered above are
                # skipped by the "in param_pos" check in the loop
                self.bind_names.values(),
            )
        else:
            order = self.bind_names.values()

        for bind_name in order:
            if bind_name in param_pos:
                continue
            bind = self.binds[bind_name]
            if (
                bind in self.post_compile_params
                or bind in self.literal_execute_params
            ):
                # map to None just to keep the name in positiontup; no
                # number is consumed.  NOTE(review): presumably these binds
                # are not rendered as "%(name)s" placeholders, so the
                # substitution below never looks them up
                param_pos[bind_name] = None  # type: ignore
            else:
                ph = f"{self._numeric_binds_identifier_char}{num}"
                num += 1
                param_pos[bind_name] = ph

        self.next_numeric_pos = num

        self.positiontup = list(param_pos)
        if self.escaped_bind_names:
            # the string uses escaped names, so re-key the lookup by them;
            # the escaping must not merge two names
            len_before = len(param_pos)
            param_pos = {
                self.escaped_bind_names.get(name, name): pos
                for name, pos in param_pos.items()
            }
            assert len(param_pos) == len_before

        # Can't use format here since % chars are not escaped.
        self.string = self._pyformat_pattern.sub(
            lambda m: param_pos[m.group(1)], self.string
        )

        if self._insertmanyvalues:
            single_values_expr = (
                # format is ok here since single_values_expr includes only
                # place-holders
                self._insertmanyvalues.single_values_expr
                % param_pos
            )
            insert_crud_params = [
                (v[0], v[1], "%s", v[3])
                for v in self._insertmanyvalues.insert_crud_params
            ]

            self._insertmanyvalues = self._insertmanyvalues._replace(
                # This has the numbers (:1, :2)
                single_values_expr=single_values_expr,
                # The single binds are instead %s so they can be formatted
                insert_crud_params=insert_crud_params,
            )
1777
1778 @util.memoized_property
1779 def _bind_processors(
1780 self,
1781 ) -> MutableMapping[
1782 str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]]
1783 ]:
1784 # mypy is not able to see the two value types as the above Union,
1785 # it just sees "object". don't know how to resolve
1786 return {
1787 key: value # type: ignore
1788 for key, value in (
1789 (
1790 self.bind_names[bindparam],
1791 (
1792 bindparam.type._cached_bind_processor(self.dialect)
1793 if not bindparam.type._is_tuple_type
1794 else tuple(
1795 elem_type._cached_bind_processor(self.dialect)
1796 for elem_type in cast(
1797 TupleType, bindparam.type
1798 ).types
1799 )
1800 ),
1801 )
1802 for bindparam in self.bind_names
1803 )
1804 if value is not None
1805 }
1806
1807 def is_subquery(self):
1808 return len(self.stack) > 1
1809
1810 @property
1811 def sql_compiler(self) -> Self:
1812 return self
1813
1814 def construct_expanded_state(
1815 self,
1816 params: Optional[_CoreSingleExecuteParams] = None,
1817 escape_names: bool = True,
1818 ) -> ExpandedState:
1819 """Return a new :class:`.ExpandedState` for a given parameter set.
1820
1821 For queries that use "expanding" or other late-rendered parameters,
1822 this method will provide for both the finalized SQL string as well
1823 as the parameters that would be used for a particular parameter set.
1824
1825 .. versionadded:: 2.0.0rc1
1826
1827 """
1828 parameters = self.construct_params(
1829 params,
1830 escape_names=escape_names,
1831 _no_postcompile=True,
1832 )
1833 return self._process_parameters_for_postcompile(
1834 parameters,
1835 )
1836
1837 def construct_params(
1838 self,
1839 params: Optional[_CoreSingleExecuteParams] = None,
1840 extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
1841 escape_names: bool = True,
1842 _group_number: Optional[int] = None,
1843 _check: bool = True,
1844 _no_postcompile: bool = False,
1845 ) -> _MutableCoreSingleExecuteParams:
1846 """return a dictionary of bind parameter keys and values"""
1847
1848 if self._render_postcompile and not _no_postcompile:
1849 assert self._post_compile_expanded_state is not None
1850 if not params:
1851 return dict(self._post_compile_expanded_state.parameters)
1852 else:
1853 raise exc.InvalidRequestError(
1854 "can't construct new parameters when render_postcompile "
1855 "is used; the statement is hard-linked to the original "
1856 "parameters. Use construct_expanded_state to generate a "
1857 "new statement and parameters."
1858 )
1859
1860 has_escaped_names = escape_names and bool(self.escaped_bind_names)
1861
1862 if extracted_parameters:
1863 # related the bound parameters collected in the original cache key
1864 # to those collected in the incoming cache key. They will not have
1865 # matching names but they will line up positionally in the same
1866 # way. The parameters present in self.bind_names may be clones of
1867 # these original cache key params in the case of DML but the .key
1868 # will be guaranteed to match.
1869 if self.cache_key is None:
1870 raise exc.CompileError(
1871 "This compiled object has no original cache key; "
1872 "can't pass extracted_parameters to construct_params"
1873 )
1874 else:
1875 orig_extracted = self.cache_key[1]
1876
1877 ckbm_tuple = self._cache_key_bind_match
1878 assert ckbm_tuple is not None
1879 ckbm, _ = ckbm_tuple
1880 resolved_extracted = {
1881 bind: extracted
1882 for b, extracted in zip(orig_extracted, extracted_parameters)
1883 for bind in ckbm[b]
1884 }
1885 else:
1886 resolved_extracted = None
1887
1888 if params:
1889 pd = {}
1890 for bindparam, name in self.bind_names.items():
1891 escaped_name = (
1892 self.escaped_bind_names.get(name, name)
1893 if has_escaped_names
1894 else name
1895 )
1896
1897 if bindparam.key in params:
1898 pd[escaped_name] = params[bindparam.key]
1899 elif name in params:
1900 pd[escaped_name] = params[name]
1901
1902 elif _check and bindparam.required:
1903 if _group_number:
1904 raise exc.InvalidRequestError(
1905 "A value is required for bind parameter %r, "
1906 "in parameter group %d"
1907 % (bindparam.key, _group_number),
1908 code="cd3x",
1909 )
1910 else:
1911 raise exc.InvalidRequestError(
1912 "A value is required for bind parameter %r"
1913 % bindparam.key,
1914 code="cd3x",
1915 )
1916 else:
1917 if resolved_extracted:
1918 value_param = resolved_extracted.get(
1919 bindparam, bindparam
1920 )
1921 else:
1922 value_param = bindparam
1923
1924 if bindparam.callable:
1925 pd[escaped_name] = value_param.effective_value
1926 else:
1927 pd[escaped_name] = value_param.value
1928 return pd
1929 else:
1930 pd = {}
1931 for bindparam, name in self.bind_names.items():
1932 escaped_name = (
1933 self.escaped_bind_names.get(name, name)
1934 if has_escaped_names
1935 else name
1936 )
1937
1938 if _check and bindparam.required:
1939 if _group_number:
1940 raise exc.InvalidRequestError(
1941 "A value is required for bind parameter %r, "
1942 "in parameter group %d"
1943 % (bindparam.key, _group_number),
1944 code="cd3x",
1945 )
1946 else:
1947 raise exc.InvalidRequestError(
1948 "A value is required for bind parameter %r"
1949 % bindparam.key,
1950 code="cd3x",
1951 )
1952
1953 if resolved_extracted:
1954 value_param = resolved_extracted.get(bindparam, bindparam)
1955 else:
1956 value_param = bindparam
1957
1958 if bindparam.callable:
1959 pd[escaped_name] = value_param.effective_value
1960 else:
1961 pd[escaped_name] = value_param.value
1962
1963 return pd
1964
1965 @util.memoized_instancemethod
1966 def _get_set_input_sizes_lookup(self):
1967 dialect = self.dialect
1968
1969 include_types = dialect.include_set_input_sizes
1970 exclude_types = dialect.exclude_set_input_sizes
1971
1972 dbapi = dialect.dbapi
1973
1974 def lookup_type(typ):
1975 dbtype = typ._unwrapped_dialect_impl(dialect).get_dbapi_type(dbapi)
1976
1977 if (
1978 dbtype is not None
1979 and (exclude_types is None or dbtype not in exclude_types)
1980 and (include_types is None or dbtype in include_types)
1981 ):
1982 return dbtype
1983 else:
1984 return None
1985
1986 inputsizes = {}
1987
1988 literal_execute_params = self.literal_execute_params
1989
1990 for bindparam in self.bind_names:
1991 if bindparam in literal_execute_params:
1992 continue
1993
1994 if bindparam.type._is_tuple_type:
1995 inputsizes[bindparam] = [
1996 lookup_type(typ)
1997 for typ in cast(TupleType, bindparam.type).types
1998 ]
1999 else:
2000 inputsizes[bindparam] = lookup_type(bindparam.type)
2001
2002 return inputsizes
2003
2004 @property
2005 def params(self):
2006 """Return the bind param dictionary embedded into this
2007 compiled object, for those values that are present.
2008
2009 .. seealso::
2010
2011 :ref:`faq_sql_expression_string` - includes a usage example for
2012 debugging use cases.
2013
2014 """
2015 return self.construct_params(_check=False)
2016
2017 def _process_parameters_for_postcompile(
2018 self,
2019 parameters: _MutableCoreSingleExecuteParams,
2020 _populate_self: bool = False,
2021 ) -> ExpandedState:
2022 """handle special post compile parameters.
2023
2024 These include:
2025
2026 * "expanding" parameters -typically IN tuples that are rendered
2027 on a per-parameter basis for an otherwise fixed SQL statement string.
2028
2029 * literal_binds compiled with the literal_execute flag. Used for
2030 things like SQL Server "TOP N" where the driver does not accommodate
2031 N as a bound parameter.
2032
2033 """
2034
2035 expanded_parameters = {}
2036 new_positiontup: Optional[List[str]]
2037
2038 pre_expanded_string = self._pre_expanded_string
2039 if pre_expanded_string is None:
2040 pre_expanded_string = self.string
2041
2042 if self.positional:
2043 new_positiontup = []
2044
2045 pre_expanded_positiontup = self._pre_expanded_positiontup
2046 if pre_expanded_positiontup is None:
2047 pre_expanded_positiontup = self.positiontup
2048
2049 else:
2050 new_positiontup = pre_expanded_positiontup = None
2051
2052 processors = self._bind_processors
2053 single_processors = cast(
2054 "Mapping[str, _BindProcessorType[Any]]", processors
2055 )
2056 tuple_processors = cast(
2057 "Mapping[str, Sequence[_BindProcessorType[Any]]]", processors
2058 )
2059
2060 new_processors: Dict[str, _BindProcessorType[Any]] = {}
2061
2062 replacement_expressions: Dict[str, Any] = {}
2063 to_update_sets: Dict[str, Any] = {}
2064
2065 # notes:
2066 # *unescaped* parameter names in:
2067 # self.bind_names, self.binds, self._bind_processors, self.positiontup
2068 #
2069 # *escaped* parameter names in:
2070 # construct_params(), replacement_expressions
2071
2072 numeric_positiontup: Optional[List[str]] = None
2073
2074 if self.positional and pre_expanded_positiontup is not None:
2075 names: Iterable[str] = pre_expanded_positiontup
2076 if self._numeric_binds:
2077 numeric_positiontup = []
2078 else:
2079 names = self.bind_names.values()
2080
2081 ebn = self.escaped_bind_names
2082 for name in names:
2083 escaped_name = ebn.get(name, name) if ebn else name
2084 parameter = self.binds[name]
2085
2086 if parameter in self.literal_execute_params:
2087 if escaped_name not in replacement_expressions:
2088 replacement_expressions[escaped_name] = (
2089 self.render_literal_bindparam(
2090 parameter,
2091 render_literal_value=parameters.pop(escaped_name),
2092 )
2093 )
2094 continue
2095
2096 if parameter in self.post_compile_params:
2097 if escaped_name in replacement_expressions:
2098 to_update = to_update_sets[escaped_name]
2099 values = None
2100 else:
2101 # we are removing the parameter from parameters
2102 # because it is a list value, which is not expected by
2103 # TypeEngine objects that would otherwise be asked to
2104 # process it. the single name is being replaced with
2105 # individual numbered parameters for each value in the
2106 # param.
2107 #
2108 # note we are also inserting *escaped* parameter names
2109 # into the given dictionary. default dialect will
2110 # use these param names directly as they will not be
2111 # in the escaped_bind_names dictionary.
2112 values = parameters.pop(name)
2113
2114 leep_res = self._literal_execute_expanding_parameter(
2115 escaped_name, parameter, values
2116 )
2117 (to_update, replacement_expr) = leep_res
2118
2119 to_update_sets[escaped_name] = to_update
2120 replacement_expressions[escaped_name] = replacement_expr
2121
2122 if not parameter.literal_execute:
2123 parameters.update(to_update)
2124 if parameter.type._is_tuple_type:
2125 assert values is not None
2126 new_processors.update(
2127 (
2128 "%s_%s_%s" % (name, i, j),
2129 tuple_processors[name][j - 1],
2130 )
2131 for i, tuple_element in enumerate(values, 1)
2132 for j, _ in enumerate(tuple_element, 1)
2133 if name in tuple_processors
2134 and tuple_processors[name][j - 1] is not None
2135 )
2136 else:
2137 new_processors.update(
2138 (key, single_processors[name])
2139 for key, _ in to_update
2140 if name in single_processors
2141 )
2142 if numeric_positiontup is not None:
2143 numeric_positiontup.extend(
2144 name for name, _ in to_update
2145 )
2146 elif new_positiontup is not None:
2147 # to_update has escaped names, but that's ok since
2148 # these are new names, that aren't in the
2149 # escaped_bind_names dict.
2150 new_positiontup.extend(name for name, _ in to_update)
2151 expanded_parameters[name] = [
2152 expand_key for expand_key, _ in to_update
2153 ]
2154 elif new_positiontup is not None:
2155 new_positiontup.append(name)
2156
2157 def process_expanding(m):
2158 key = m.group(1)
2159 expr = replacement_expressions[key]
2160
2161 # if POSTCOMPILE included a bind_expression, render that
2162 # around each element
2163 if m.group(2):
2164 tok = m.group(2).split("~~")
2165 be_left, be_right = tok[1], tok[3]
2166 expr = ", ".join(
2167 "%s%s%s" % (be_left, exp, be_right)
2168 for exp in expr.split(", ")
2169 )
2170 return expr
2171
2172 statement = re.sub(
2173 self._post_compile_pattern, process_expanding, pre_expanded_string
2174 )
2175
2176 if numeric_positiontup is not None:
2177 assert new_positiontup is not None
2178 param_pos = {
2179 key: f"{self._numeric_binds_identifier_char}{num}"
2180 for num, key in enumerate(
2181 numeric_positiontup, self.next_numeric_pos
2182 )
2183 }
2184 # Can't use format here since % chars are not escaped.
2185 statement = self._pyformat_pattern.sub(
2186 lambda m: param_pos[m.group(1)], statement
2187 )
2188 new_positiontup.extend(numeric_positiontup)
2189
2190 expanded_state = ExpandedState(
2191 statement,
2192 parameters,
2193 new_processors,
2194 new_positiontup,
2195 expanded_parameters,
2196 )
2197
2198 if _populate_self:
2199 # this is for the "render_postcompile" flag, which is not
2200 # otherwise used internally and is for end-user debugging and
2201 # special use cases.
2202 self._pre_expanded_string = pre_expanded_string
2203 self._pre_expanded_positiontup = pre_expanded_positiontup
2204 self.string = expanded_state.statement
2205 self.positiontup = (
2206 list(expanded_state.positiontup or ())
2207 if self.positional
2208 else None
2209 )
2210 self._post_compile_expanded_state = expanded_state
2211
2212 return expanded_state
2213
2214 @util.preload_module("sqlalchemy.engine.cursor")
2215 def _create_result_map(self):
2216 """utility method used for unit tests only."""
2217 cursor = util.preloaded.engine_cursor
2218 return cursor.CursorResultMetaData._create_description_match_map(
2219 self._result_columns
2220 )
2221
    # assigned by crud.py for insert/update statements.  Called with a
    # Column, it returns the key under which that column's value is found
    # in the execution parameters; the primary-key getters in this class
    # use it (via _within_exec_param_key_getter) as a parameters-dict key.
    _get_bind_name_for_col: _BindNameForColProtocol
2224
2225 @util.memoized_property
2226 def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
2227 getter = self._get_bind_name_for_col
2228 return getter
2229
    @util.memoized_property
    @util.preload_module("sqlalchemy.engine.result")
    def _inserted_primary_key_from_lastrowid_getter(self):
        """Return a callable that assembles the inserted primary key.

        The returned ``get(lastrowid, parameters)`` builds a result tuple
        keyed by the table's primary key column keys.  Each value is taken
        from the INSERT's ``parameters`` (looked up by bind name, ``None``
        when absent), except that when ``lastrowid`` (after the
        autoincrement column type's result processor, if any) is not
        ``None``, the autoincrement column's value comes from it instead.
        If that column also has a bind in this statement, a non-``None``
        value supplied for it in ``parameters`` wins over ``lastrowid``.

        Built once per compiler (memoized property).  Assumes
        ``self.compile_state`` is set and its statement is an INSERT.
        """
        result = util.preloaded.engine_result

        param_key_getter = self._within_exec_param_key_getter

        assert self.compile_state is not None
        statement = self.compile_state.statement

        if TYPE_CHECKING:
            assert isinstance(statement, Insert)

        table = statement.table

        # one (getter, column) pair per primary key column; each getter
        # fetches that column's value from the parameters dict by its bind
        # name, yielding None when the key is absent
        getters = [
            (operator.methodcaller("get", param_key_getter(col), None), col)
            for col in table.primary_key
        ]

        autoinc_getter = None
        autoinc_col = table._autoincrement_column
        if autoinc_col is not None:
            # apply type post processors to the lastrowid
            lastrowid_processor = autoinc_col.type._cached_result_processor(
                self.dialect, None
            )
            autoinc_key = param_key_getter(autoinc_col)

            # if a bind value is present for the autoincrement column
            # in the parameters, we need to do the logic dictated by
            # #7998; honor a non-None user-passed parameter over lastrowid.
            # previously in the 1.4 series we weren't fetching lastrowid
            # at all if the key were present in the parameters
            if autoinc_key in self.binds:

                def _autoinc_getter(lastrowid, parameters):
                    param_value = parameters.get(autoinc_key, lastrowid)
                    if param_value is not None:
                        # they supplied non-None parameter, use that.
                        # SQLite at least is observed to return the wrong
                        # cursor.lastrowid for INSERT..ON CONFLICT so it
                        # can't be used in all cases
                        return param_value
                    else:
                        # use lastrowid
                        return lastrowid

                # work around mypy https://github.com/python/mypy/issues/14027
                autoinc_getter = _autoinc_getter

        else:
            # no autoincrement column: lastrowid is passed through as is
            lastrowid_processor = None

        row_fn = result.result_tuple([col.key for col in table.primary_key])

        def get(lastrowid, parameters):
            """given cursor.lastrowid value and the parameters used for INSERT,
            return a "row" that represents the primary key, either by
            using the "lastrowid" or by extracting values from the parameters
            that were sent along with the INSERT.

            """
            if lastrowid_processor is not None:
                lastrowid = lastrowid_processor(lastrowid)

            if lastrowid is None:
                # no usable lastrowid; every value comes from parameters
                return row_fn(getter(parameters) for getter, col in getters)
            else:
                # the autoincrement column takes lastrowid (or a non-None
                # user parameter, see _autoinc_getter); others use params
                return row_fn(
                    (
                        (
                            autoinc_getter(lastrowid, parameters)
                            if autoinc_getter is not None
                            else lastrowid
                        )
                        if col is autoinc_col
                        else getter(parameters)
                    )
                    for getter, col in getters
                )

        return get
2313
2314 @util.memoized_property
2315 @util.preload_module("sqlalchemy.engine.result")
2316 def _inserted_primary_key_from_returning_getter(self):
2317 result = util.preloaded.engine_result
2318
2319 assert self.compile_state is not None
2320 statement = self.compile_state.statement
2321
2322 if TYPE_CHECKING:
2323 assert isinstance(statement, Insert)
2324
2325 param_key_getter = self._within_exec_param_key_getter
2326 table = statement.table
2327
2328 returning = self.implicit_returning
2329 assert returning is not None
2330 ret = {col: idx for idx, col in enumerate(returning)}
2331
2332 getters = cast(
2333 "List[Tuple[Callable[[Any], Any], bool]]",
2334 [
2335 (
2336 (operator.itemgetter(ret[col]), True)
2337 if col in ret
2338 else (
2339 operator.methodcaller(
2340 "get", param_key_getter(col), None
2341 ),
2342 False,
2343 )
2344 )
2345 for col in table.primary_key
2346 ],
2347 )
2348
2349 row_fn = result.result_tuple([col.key for col in table.primary_key])
2350
2351 def get(row, parameters):
2352 return row_fn(
2353 getter(row) if use_row else getter(parameters)
2354 for getter, use_row in getters
2355 )
2356
2357 return get
2358
2359 def default_from(self) -> str:
2360 """Called when a SELECT statement has no froms, and no FROM clause is
2361 to be appended.
2362
2363 Gives Oracle Database a chance to tack on a ``FROM DUAL`` to the string
2364 output.
2365
2366 """
2367 return ""
2368
2369 def visit_override_binds(self, override_binds, **kw):
2370 """SQL compile the nested element of an _OverrideBinds with
2371 bindparams swapped out.
2372
2373 The _OverrideBinds is not normally expected to be compiled; it
2374 is meant to be used when an already cached statement is to be used,
2375 the compilation was already performed, and only the bound params should
2376 be swapped in at execution time.
2377
2378 However, there are test cases that exericise this object, and
2379 additionally the ORM subquery loader is known to feed in expressions
2380 which include this construct into new queries (discovered in #11173),
2381 so it has to do the right thing at compile time as well.
2382
2383 """
2384
2385 # get SQL text first
2386 sqltext = override_binds.element._compiler_dispatch(self, **kw)
2387
2388 # for a test compile that is not for caching, change binds after the
2389 # fact. note that we don't try to
2390 # swap the bindparam as we compile, because our element may be
2391 # elsewhere in the statement already (e.g. a subquery or perhaps a
2392 # CTE) and was already visited / compiled. See
2393 # test_relationship_criteria.py ->
2394 # test_selectinload_local_criteria_subquery
2395 for k in override_binds.translate:
2396 if k not in self.binds:
2397 continue
2398 bp = self.binds[k]
2399
2400 # so this would work, just change the value of bp in place.
2401 # but we dont want to mutate things outside.
2402 # bp.value = override_binds.translate[bp.key]
2403 # continue
2404
2405 # instead, need to replace bp with new_bp or otherwise accommodate
2406 # in all internal collections
2407 new_bp = bp._with_value(
2408 override_binds.translate[bp.key],
2409 maintain_key=True,
2410 required=False,
2411 )
2412
2413 name = self.bind_names[bp]
2414 self.binds[k] = self.binds[name] = new_bp
2415 self.bind_names[new_bp] = name
2416 self.bind_names.pop(bp, None)
2417
2418 if bp in self.post_compile_params:
2419 self.post_compile_params |= {new_bp}
2420 if bp in self.literal_execute_params:
2421 self.literal_execute_params |= {new_bp}
2422
2423 ckbm_tuple = self._cache_key_bind_match
2424 if ckbm_tuple:
2425 ckbm, cksm = ckbm_tuple
2426 for bp in bp._cloned_set:
2427 if bp.key in cksm:
2428 cb = cksm[bp.key]
2429 ckbm[cb].append(new_bp)
2430
2431 return sqltext
2432
2433 def visit_grouping(self, grouping, asfrom=False, **kwargs):
2434 return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
2435
2436 def visit_select_statement_grouping(self, grouping, **kwargs):
2437 return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
2438
2439 def visit_label_reference(
2440 self, element, within_columns_clause=False, **kwargs
2441 ):
2442 if self.stack and self.dialect.supports_simple_order_by_label:
2443 try:
2444 compile_state = cast(
2445 "Union[SelectState, CompoundSelectState]",
2446 self.stack[-1]["compile_state"],
2447 )
2448 except KeyError as ke:
2449 raise exc.CompileError(
2450 "Can't resolve label reference for ORDER BY / "
2451 "GROUP BY / DISTINCT etc."
2452 ) from ke
2453
2454 (
2455 with_cols,
2456 only_froms,
2457 only_cols,
2458 ) = compile_state._label_resolve_dict
2459 if within_columns_clause:
2460 resolve_dict = only_froms
2461 else:
2462 resolve_dict = only_cols
2463
2464 # this can be None in the case that a _label_reference()
2465 # were subject to a replacement operation, in which case
2466 # the replacement of the Label element may have changed
2467 # to something else like a ColumnClause expression.
2468 order_by_elem = element.element._order_by_label_element
2469
2470 if (
2471 order_by_elem is not None
2472 and order_by_elem.name in resolve_dict
2473 and order_by_elem.shares_lineage(
2474 resolve_dict[order_by_elem.name]
2475 )
2476 ):
2477 kwargs["render_label_as_label"] = (
2478 element.element._order_by_label_element
2479 )
2480 return self.process(
2481 element.element,
2482 within_columns_clause=within_columns_clause,
2483 **kwargs,
2484 )
2485
2486 def visit_textual_label_reference(
2487 self, element, within_columns_clause=False, **kwargs
2488 ):
2489 if not self.stack:
2490 # compiling the element outside of the context of a SELECT
2491 return self.process(element._text_clause)
2492
2493 try:
2494 compile_state = cast(
2495 "Union[SelectState, CompoundSelectState]",
2496 self.stack[-1]["compile_state"],
2497 )
2498 except KeyError as ke:
2499 coercions._no_text_coercion(
2500 element.element,
2501 extra=(
2502 "Can't resolve label reference for ORDER BY / "
2503 "GROUP BY / DISTINCT etc."
2504 ),
2505 exc_cls=exc.CompileError,
2506 err=ke,
2507 )
2508
2509 with_cols, only_froms, only_cols = compile_state._label_resolve_dict
2510 try:
2511 if within_columns_clause:
2512 col = only_froms[element.element]
2513 else:
2514 col = with_cols[element.element]
2515 except KeyError as err:
2516 coercions._no_text_coercion(
2517 element.element,
2518 extra=(
2519 "Can't resolve label reference for ORDER BY / "
2520 "GROUP BY / DISTINCT etc."
2521 ),
2522 exc_cls=exc.CompileError,
2523 err=err,
2524 )
2525 else:
2526 kwargs["render_label_as_label"] = col
2527 return self.process(
2528 col, within_columns_clause=within_columns_clause, **kwargs
2529 )
2530
2531 def visit_label(
2532 self,
2533 label,
2534 add_to_result_map=None,
2535 within_label_clause=False,
2536 within_columns_clause=False,
2537 render_label_as_label=None,
2538 result_map_targets=(),
2539 **kw,
2540 ):
2541 # only render labels within the columns clause
2542 # or ORDER BY clause of a select. dialect-specific compilers
2543 # can modify this behavior.
2544 render_label_with_as = (
2545 within_columns_clause and not within_label_clause
2546 )
2547 render_label_only = render_label_as_label is label
2548
2549 if render_label_only or render_label_with_as:
2550 if isinstance(label.name, elements._truncated_label):
2551 labelname = self._truncated_identifier("colident", label.name)
2552 else:
2553 labelname = label.name
2554
2555 if render_label_with_as:
2556 if add_to_result_map is not None:
2557 add_to_result_map(
2558 labelname,
2559 label.name,
2560 (label, labelname) + label._alt_names + result_map_targets,
2561 label.type,
2562 )
2563 return (
2564 label.element._compiler_dispatch(
2565 self,
2566 within_columns_clause=True,
2567 within_label_clause=True,
2568 **kw,
2569 )
2570 + OPERATORS[operators.as_]
2571 + self.preparer.format_label(label, labelname)
2572 )
2573 elif render_label_only:
2574 return self.preparer.format_label(label, labelname)
2575 else:
2576 return label.element._compiler_dispatch(
2577 self, within_columns_clause=False, **kw
2578 )
2579
2580 def _fallback_column_name(self, column):
2581 raise exc.CompileError(
2582 "Cannot compile Column object until its 'name' is assigned."
2583 )
2584
2585 def visit_lambda_element(self, element, **kw):
2586 sql_element = element._resolved
2587 return self.process(sql_element, **kw)
2588
2589 def visit_column(
2590 self,
2591 column: ColumnClause[Any],
2592 add_to_result_map: Optional[_ResultMapAppender] = None,
2593 include_table: bool = True,
2594 result_map_targets: Tuple[Any, ...] = (),
2595 ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
2596 **kwargs: Any,
2597 ) -> str:
2598 name = orig_name = column.name
2599 if name is None:
2600 name = self._fallback_column_name(column)
2601
2602 is_literal = column.is_literal
2603 if not is_literal and isinstance(name, elements._truncated_label):
2604 name = self._truncated_identifier("colident", name)
2605
2606 if add_to_result_map is not None:
2607 targets = (column, name, column.key) + result_map_targets
2608 if column._tq_label:
2609 targets += (column._tq_label,)
2610
2611 add_to_result_map(name, orig_name, targets, column.type)
2612
2613 if is_literal:
2614 # note we are not currently accommodating for
2615 # literal_column(quoted_name('ident', True)) here
2616 name = self.escape_literal_column(name)
2617 else:
2618 name = self.preparer.quote(name)
2619 table = column.table
2620 if table is None or not include_table or not table.named_with_column:
2621 return name
2622 else:
2623 effective_schema = self.preparer.schema_for_object(table)
2624
2625 if effective_schema:
2626 schema_prefix = (
2627 self.preparer.quote_schema(effective_schema) + "."
2628 )
2629 else:
2630 schema_prefix = ""
2631
2632 if TYPE_CHECKING:
2633 assert isinstance(table, NamedFromClause)
2634 tablename = table.name
2635
2636 if (
2637 not effective_schema
2638 and ambiguous_table_name_map
2639 and tablename in ambiguous_table_name_map
2640 ):
2641 tablename = ambiguous_table_name_map[tablename]
2642
2643 if isinstance(tablename, elements._truncated_label):
2644 tablename = self._truncated_identifier("alias", tablename)
2645
2646 return schema_prefix + self.preparer.quote(tablename) + "." + name
2647
2648 def visit_collation(self, element, **kw):
2649 return self.preparer.format_collation(element.collation)
2650
2651 def visit_fromclause(self, fromclause, **kwargs):
2652 return fromclause.name
2653
2654 def visit_index(self, index, **kwargs):
2655 return index.name
2656
2657 def visit_typeclause(self, typeclause, **kw):
2658 kw["type_expression"] = typeclause
2659 kw["identifier_preparer"] = self.preparer
2660 return self.dialect.type_compiler_instance.process(
2661 typeclause.type, **kw
2662 )
2663
2664 def post_process_text(self, text):
2665 if self.preparer._double_percents:
2666 text = text.replace("%", "%%")
2667 return text
2668
2669 def escape_literal_column(self, text):
2670 if self.preparer._double_percents:
2671 text = text.replace("%", "%%")
2672 return text
2673
2674 def visit_textclause(self, textclause, add_to_result_map=None, **kw):
2675 def do_bindparam(m):
2676 name = m.group(1)
2677 if name in textclause._bindparams:
2678 return self.process(textclause._bindparams[name], **kw)
2679 else:
2680 return self.bindparam_string(name, **kw)
2681
2682 if not self.stack:
2683 self.isplaintext = True
2684
2685 if add_to_result_map:
2686 # text() object is present in the columns clause of a
2687 # select(). Add a no-name entry to the result map so that
2688 # row[text()] produces a result
2689 add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE)
2690
2691 # un-escape any \:params
2692 return BIND_PARAMS_ESC.sub(
2693 lambda m: m.group(1),
2694 BIND_PARAMS.sub(
2695 do_bindparam, self.post_process_text(textclause.text)
2696 ),
2697 )
2698
    def visit_textual_select(
        self, taf, compound_index=None, asfrom=False, **kw
    ):
        """Compile a TextualSelect: a textual SQL element with declared
        result columns (``taf.column_args``).

        A new stack entry is pushed for ``taf`` while it is compiled and
        popped afterwards.  The declared columns are added to the result
        map when one is wanted, which is any of these cases:

        * this is the top-level statement;
        * this is the first element of a compound select that requested
          a result map;
        * this is a nested element that requested one.

        Any collected CTEs are rendered ahead of the text.

        ``asfrom`` is accepted for signature compatibility and not used.
        """
        toplevel = not self.stack
        entry = self._default_stack_entry if toplevel else self.stack[-1]

        new_entry: _CompilerStackEntry = {
            "correlate_froms": set(),
            "asfrom_froms": set(),
            "selectable": taf,
        }
        self.stack.append(new_entry)

        if taf._independent_ctes:
            self._dispatch_independent_ctes(taf, kw)

        populate_result_map = (
            toplevel
            or (
                compound_index == 0
                and entry.get("need_result_map_for_compound", False)
            )
            or entry.get("need_result_map_for_nested", False)
        )

        if populate_result_map:
            # both "ordered columns" flags follow taf.positional
            self._ordered_columns = self._textual_ordered_columns = (
                taf.positional
            )

            # enable looser result column matching when the SQL text links to
            # Column objects by name only
            self._loose_column_name_matching = not taf.positional and bool(
                taf.column_args
            )

            # register each declared column in the result map
            for c in taf.column_args:
                self.process(
                    c,
                    within_columns_clause=True,
                    add_to_result_map=self._add_to_result_map,
                )

        text = self.process(taf.element, **kw)
        if self.ctes:
            # nesting_level is None only for the top-level statement
            nesting_level = len(self.stack) if not toplevel else None
            text = self._render_cte_clause(nesting_level=nesting_level) + text

        self.stack.pop(-1)

        return text
2750
2751 def visit_null(self, expr: Null, **kw: Any) -> str:
2752 return "NULL"
2753
2754 def visit_true(self, expr: True_, **kw: Any) -> str:
2755 if self.dialect.supports_native_boolean:
2756 return "true"
2757 else:
2758 return "1"
2759
2760 def visit_false(self, expr: False_, **kw: Any) -> str:
2761 if self.dialect.supports_native_boolean:
2762 return "false"
2763 else:
2764 return "0"
2765
2766 def _generate_delimited_list(self, elements, separator, **kw):
2767 return separator.join(
2768 s
2769 for s in (c._compiler_dispatch(self, **kw) for c in elements)
2770 if s
2771 )
2772
2773 def _generate_delimited_and_list(self, clauses, **kw):
2774 lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean(
2775 operators.and_,
2776 elements.True_._singleton,
2777 elements.False_._singleton,
2778 clauses,
2779 )
2780 if lcc == 1:
2781 return clauses[0]._compiler_dispatch(self, **kw)
2782 else:
2783 separator = OPERATORS[operators.and_]
2784 return separator.join(
2785 s
2786 for s in (c._compiler_dispatch(self, **kw) for c in clauses)
2787 if s
2788 )
2789
2790 def visit_tuple(self, clauselist, **kw):
2791 return "(%s)" % self.visit_clauselist(clauselist, **kw)
2792
2793 def visit_clauselist(self, clauselist, **kw):
2794 sep = clauselist.operator
2795 if sep is None:
2796 sep = " "
2797 else:
2798 sep = OPERATORS[clauselist.operator]
2799
2800 return self._generate_delimited_list(clauselist.clauses, sep, **kw)
2801
2802 def visit_expression_clauselist(self, clauselist, **kw):
2803 operator_ = clauselist.operator
2804
2805 disp = self._get_operator_dispatch(
2806 operator_, "expression_clauselist", None
2807 )
2808 if disp:
2809 return disp(clauselist, operator_, **kw)
2810
2811 try:
2812 opstring = OPERATORS[operator_]
2813 except KeyError as err:
2814 raise exc.UnsupportedCompilationError(self, operator_) from err
2815 else:
2816 kw["_in_operator_expression"] = True
2817 return self._generate_delimited_list(
2818 clauselist.clauses, opstring, **kw
2819 )
2820
2821 def visit_case(self, clause, **kwargs):
2822 x = "CASE "
2823 if clause.value is not None:
2824 x += clause.value._compiler_dispatch(self, **kwargs) + " "
2825 for cond, result in clause.whens:
2826 x += (
2827 "WHEN "
2828 + cond._compiler_dispatch(self, **kwargs)
2829 + " THEN "
2830 + result._compiler_dispatch(self, **kwargs)
2831 + " "
2832 )
2833 if clause.else_ is not None:
2834 x += (
2835 "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " "
2836 )
2837 x += "END"
2838 return x
2839
2840 def visit_type_coerce(self, type_coerce, **kw):
2841 return type_coerce.typed_expression._compiler_dispatch(self, **kw)
2842
2843 def visit_cast(self, cast, **kwargs):
2844 type_clause = cast.typeclause._compiler_dispatch(self, **kwargs)
2845 match = re.match("(.*)( COLLATE .*)", type_clause)
2846 return "CAST(%s AS %s)%s" % (
2847 cast.clause._compiler_dispatch(self, **kwargs),
2848 match.group(1) if match else type_clause,
2849 match.group(2) if match else "",
2850 )
2851
2852 def _format_frame_clause(self, range_, **kw):
2853 return "%s AND %s" % (
2854 (
2855 "UNBOUNDED PRECEDING"
2856 if range_[0] is elements.RANGE_UNBOUNDED
2857 else (
2858 "CURRENT ROW"
2859 if range_[0] is elements.RANGE_CURRENT
2860 else (
2861 "%s PRECEDING"
2862 % (
2863 self.process(
2864 elements.literal(abs(range_[0])), **kw
2865 ),
2866 )
2867 if range_[0] < 0
2868 else "%s FOLLOWING"
2869 % (self.process(elements.literal(range_[0]), **kw),)
2870 )
2871 )
2872 ),
2873 (
2874 "UNBOUNDED FOLLOWING"
2875 if range_[1] is elements.RANGE_UNBOUNDED
2876 else (
2877 "CURRENT ROW"
2878 if range_[1] is elements.RANGE_CURRENT
2879 else (
2880 "%s PRECEDING"
2881 % (
2882 self.process(
2883 elements.literal(abs(range_[1])), **kw
2884 ),
2885 )
2886 if range_[1] < 0
2887 else "%s FOLLOWING"
2888 % (self.process(elements.literal(range_[1]), **kw),)
2889 )
2890 )
2891 ),
2892 )
2893
2894 def visit_over(self, over, **kwargs):
2895 text = over.element._compiler_dispatch(self, **kwargs)
2896 if over.range_ is not None:
2897 range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
2898 over.range_, **kwargs
2899 )
2900 elif over.rows is not None:
2901 range_ = "ROWS BETWEEN %s" % self._format_frame_clause(
2902 over.rows, **kwargs
2903 )
2904 elif over.groups is not None:
2905 range_ = "GROUPS BETWEEN %s" % self._format_frame_clause(
2906 over.groups, **kwargs
2907 )
2908 else:
2909 range_ = None
2910
2911 return "%s OVER (%s)" % (
2912 text,
2913 " ".join(
2914 [
2915 "%s BY %s"
2916 % (word, clause._compiler_dispatch(self, **kwargs))
2917 for word, clause in (
2918 ("PARTITION", over.partition_by),
2919 ("ORDER", over.order_by),
2920 )
2921 if clause is not None and len(clause)
2922 ]
2923 + ([range_] if range_ else [])
2924 ),
2925 )
2926
2927 def visit_withingroup(self, withingroup, **kwargs):
2928 return "%s WITHIN GROUP (ORDER BY %s)" % (
2929 withingroup.element._compiler_dispatch(self, **kwargs),
2930 withingroup.order_by._compiler_dispatch(self, **kwargs),
2931 )
2932
2933 def visit_funcfilter(self, funcfilter, **kwargs):
2934 return "%s FILTER (WHERE %s)" % (
2935 funcfilter.func._compiler_dispatch(self, **kwargs),
2936 funcfilter.criterion._compiler_dispatch(self, **kwargs),
2937 )
2938
2939 def visit_extract(self, extract, **kwargs):
2940 field = self.extract_map.get(extract.field, extract.field)
2941 return "EXTRACT(%s FROM %s)" % (
2942 field,
2943 extract.expr._compiler_dispatch(self, **kwargs),
2944 )
2945
2946 def visit_scalar_function_column(self, element, **kw):
2947 compiled_fn = self.visit_function(element.fn, **kw)
2948 compiled_col = self.visit_column(element, **kw)
2949 return "(%s).%s" % (compiled_fn, compiled_col)
2950
2951 def visit_function(
2952 self,
2953 func: Function[Any],
2954 add_to_result_map: Optional[_ResultMapAppender] = None,
2955 **kwargs: Any,
2956 ) -> str:
2957 if add_to_result_map is not None:
2958 add_to_result_map(func.name, func.name, (func.name,), func.type)
2959
2960 disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
2961
2962 text: str
2963
2964 if disp:
2965 text = disp(func, **kwargs)
2966 else:
2967 name = FUNCTIONS.get(func._deannotate().__class__, None)
2968 if name:
2969 if func._has_args:
2970 name += "%(expr)s"
2971 else:
2972 name = func.name
2973 name = (
2974 self.preparer.quote(name)
2975 if self.preparer._requires_quotes_illegal_chars(name)
2976 or isinstance(name, elements.quoted_name)
2977 else name
2978 )
2979 name = name + "%(expr)s"
2980 text = ".".join(
2981 [
2982 (
2983 self.preparer.quote(tok)
2984 if self.preparer._requires_quotes_illegal_chars(tok)
2985 or isinstance(name, elements.quoted_name)
2986 else tok
2987 )
2988 for tok in func.packagenames
2989 ]
2990 + [name]
2991 ) % {"expr": self.function_argspec(func, **kwargs)}
2992
2993 if func._with_ordinality:
2994 text += " WITH ORDINALITY"
2995 return text
2996
2997 def visit_next_value_func(self, next_value, **kw):
2998 return self.visit_sequence(next_value.sequence)
2999
3000 def visit_sequence(self, sequence, **kw):
3001 raise NotImplementedError(
3002 "Dialect '%s' does not support sequence increments."
3003 % self.dialect.name
3004 )
3005
3006 def function_argspec(self, func: Function[Any], **kwargs: Any) -> str:
3007 return func.clause_expr._compiler_dispatch(self, **kwargs)
3008
    def visit_compound_select(
        self, cs, asfrom=False, compound_index=None, **kwargs
    ):
        """Compile a compound SELECT.

        Steps:

        1. Build the compile state for ``cs``; at the top level it becomes
           ``self.compile_state`` if none is set yet.
        2. Push a stack entry that shares the enclosing entry's correlation
           sets.
        3. Compile each member SELECT, joined by the keyword from
           ``self.compound_keywords``.
        4. Append the GROUP BY, ORDER BY and row-limiting clauses.
        5. Prefix any collected CTEs.
        """
        toplevel = not self.stack

        compile_state = cs._compile_state_factory(cs, self, **kwargs)

        if toplevel and not self.compile_state:
            self.compile_state = compile_state

        compound_stmt = compile_state.statement

        entry = self._default_stack_entry if toplevel else self.stack[-1]
        # a result map is needed at the top level, or when this is the first
        # member (compound_index None or 0) of an enclosing compound that
        # requested one
        need_result_map = toplevel or (
            not compound_index
            and entry.get("need_result_map_for_compound", False)
        )

        # indicates there is already a CompoundSelect in play
        if compound_index == 0:
            entry["select_0"] = cs

        self.stack.append(
            {
                "correlate_froms": entry["correlate_froms"],
                "asfrom_froms": entry["asfrom_froms"],
                "selectable": cs,
                "compile_state": compile_state,
                "need_result_map_for_compound": need_result_map,
            }
        )

        if compound_stmt._independent_ctes:
            self._dispatch_independent_ctes(compound_stmt, kwargs)

        keyword = self.compound_keywords[cs.keyword]

        # each member gets its position so it can tell whether it is first
        text = (" " + keyword + " ").join(
            (
                c._compiler_dispatch(
                    self, asfrom=asfrom, compound_index=i, **kwargs
                )
                for i, c in enumerate(cs.selects)
            )
        )

        # include_table=False is passed to the trailing clauses; when it
        # reaches visit_column, columns render without a table prefix
        kwargs["include_table"] = False
        text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs))
        text += self.order_by_clause(cs, **kwargs)
        if cs._has_row_limiting_clause:
            text += self._row_limit_clause(cs, **kwargs)

        if self.ctes:
            # nesting_level is None only for the top-level statement
            nesting_level = len(self.stack) if not toplevel else None
            text = (
                self._render_cte_clause(
                    nesting_level=nesting_level,
                    include_following_stack=True,
                )
                + text
            )

        self.stack.pop(-1)
        return text
3073
3074 def _row_limit_clause(self, cs, **kwargs):
3075 if cs._fetch_clause is not None:
3076 return self.fetch_clause(cs, **kwargs)
3077 else:
3078 return self.limit_clause(cs, **kwargs)
3079
3080 def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
3081 attrname = "visit_%s_%s%s" % (
3082 operator_.__name__,
3083 qualifier1,
3084 "_" + qualifier2 if qualifier2 else "",
3085 )
3086 return getattr(self, attrname, None)
3087
3088 def visit_unary(
3089 self, unary, add_to_result_map=None, result_map_targets=(), **kw
3090 ):
3091 if add_to_result_map is not None:
3092 result_map_targets += (unary,)
3093 kw["add_to_result_map"] = add_to_result_map
3094 kw["result_map_targets"] = result_map_targets
3095
3096 if unary.operator:
3097 if unary.modifier:
3098 raise exc.CompileError(
3099 "Unary expression does not support operator "
3100 "and modifier simultaneously"
3101 )
3102 disp = self._get_operator_dispatch(
3103 unary.operator, "unary", "operator"
3104 )
3105 if disp:
3106 return disp(unary, unary.operator, **kw)
3107 else:
3108 return self._generate_generic_unary_operator(
3109 unary, OPERATORS[unary.operator], **kw
3110 )
3111 elif unary.modifier:
3112 disp = self._get_operator_dispatch(
3113 unary.modifier, "unary", "modifier"
3114 )
3115 if disp:
3116 return disp(unary, unary.modifier, **kw)
3117 else:
3118 return self._generate_generic_unary_modifier(
3119 unary, OPERATORS[unary.modifier], **kw
3120 )
3121 else:
3122 raise exc.CompileError(
3123 "Unary expression has no operator or modifier"
3124 )
3125
3126 def visit_truediv_binary(self, binary, operator, **kw):
3127 if self.dialect.div_is_floordiv:
3128 return (
3129 self.process(binary.left, **kw)
3130 + " / "
3131 # TODO: would need a fast cast again here,
3132 # unless we want to use an implicit cast like "+ 0.0"
3133 + self.process(
3134 elements.Cast(
3135 binary.right,
3136 (
3137 binary.right.type
3138 if binary.right.type._type_affinity
3139 is sqltypes.Numeric
3140 else sqltypes.Numeric()
3141 ),
3142 ),
3143 **kw,
3144 )
3145 )
3146 else:
3147 return (
3148 self.process(binary.left, **kw)
3149 + " / "
3150 + self.process(binary.right, **kw)
3151 )
3152
3153 def visit_floordiv_binary(self, binary, operator, **kw):
3154 if (
3155 self.dialect.div_is_floordiv
3156 and binary.right.type._type_affinity is sqltypes.Integer
3157 ):
3158 return (
3159 self.process(binary.left, **kw)
3160 + " / "
3161 + self.process(binary.right, **kw)
3162 )
3163 else:
3164 return "FLOOR(%s)" % (
3165 self.process(binary.left, **kw)
3166 + " / "
3167 + self.process(binary.right, **kw)
3168 )
3169
3170 def visit_is_true_unary_operator(self, element, operator, **kw):
3171 if (
3172 element._is_implicitly_boolean
3173 or self.dialect.supports_native_boolean
3174 ):
3175 return self.process(element.element, **kw)
3176 else:
3177 return "%s = 1" % self.process(element.element, **kw)
3178
3179 def visit_is_false_unary_operator(self, element, operator, **kw):
3180 if (
3181 element._is_implicitly_boolean
3182 or self.dialect.supports_native_boolean
3183 ):
3184 return "NOT %s" % self.process(element.element, **kw)
3185 else:
3186 return "%s = 0" % self.process(element.element, **kw)
3187
3188 def visit_not_match_op_binary(self, binary, operator, **kw):
3189 return "NOT %s" % self.visit_binary(
3190 binary, override_operator=operators.match_op
3191 )
3192
3193 def visit_not_in_op_binary(self, binary, operator, **kw):
3194 # The brackets are required in the NOT IN operation because the empty
3195 # case is handled using the form "(col NOT IN (null) OR 1 = 1)".
3196 # The presence of the OR makes the brackets required.
3197 return "(%s)" % self._generate_generic_binary(
3198 binary, OPERATORS[operator], **kw
3199 )
3200
3201 def visit_empty_set_op_expr(self, type_, expand_op, **kw):
3202 if expand_op is operators.not_in_op:
3203 if len(type_) > 1:
3204 return "(%s)) OR (1 = 1" % (
3205 ", ".join("NULL" for element in type_)
3206 )
3207 else:
3208 return "NULL) OR (1 = 1"
3209 elif expand_op is operators.in_op:
3210 if len(type_) > 1:
3211 return "(%s)) AND (1 != 1" % (
3212 ", ".join("NULL" for element in type_)
3213 )
3214 else:
3215 return "NULL) AND (1 != 1"
3216 else:
3217 return self.visit_empty_set_expr(type_)
3218
3219 def visit_empty_set_expr(self, element_types, **kw):
3220 raise NotImplementedError(
3221 "Dialect '%s' does not support empty set expression."
3222 % self.dialect.name
3223 )
3224
3225 def _literal_execute_expanding_parameter_literal_binds(
3226 self, parameter, values, bind_expression_template=None
3227 ):
3228 typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
3229
3230 if not values:
3231 # empty IN expression. note we don't need to use
3232 # bind_expression_template here because there are no
3233 # expressions to render.
3234
3235 if typ_dialect_impl._is_tuple_type:
3236 replacement_expression = (
3237 "VALUES " if self.dialect.tuple_in_values else ""
3238 ) + self.visit_empty_set_op_expr(
3239 parameter.type.types, parameter.expand_op
3240 )
3241
3242 else:
3243 replacement_expression = self.visit_empty_set_op_expr(
3244 [parameter.type], parameter.expand_op
3245 )
3246
3247 elif typ_dialect_impl._is_tuple_type or (
3248 typ_dialect_impl._isnull
3249 and isinstance(values[0], collections_abc.Sequence)
3250 and not isinstance(values[0], (str, bytes))
3251 ):
3252 if typ_dialect_impl._has_bind_expression:
3253 raise NotImplementedError(
3254 "bind_expression() on TupleType not supported with "
3255 "literal_binds"
3256 )
3257
3258 replacement_expression = (
3259 "VALUES " if self.dialect.tuple_in_values else ""
3260 ) + ", ".join(
3261 "(%s)"
3262 % (
3263 ", ".join(
3264 self.render_literal_value(value, param_type)
3265 for value, param_type in zip(
3266 tuple_element, parameter.type.types
3267 )
3268 )
3269 )
3270 for i, tuple_element in enumerate(values)
3271 )
3272 else:
3273 if bind_expression_template:
3274 post_compile_pattern = self._post_compile_pattern
3275 m = post_compile_pattern.search(bind_expression_template)
3276 assert m and m.group(
3277 2
3278 ), "unexpected format for expanding parameter"
3279
3280 tok = m.group(2).split("~~")
3281 be_left, be_right = tok[1], tok[3]
3282 replacement_expression = ", ".join(
3283 "%s%s%s"
3284 % (
3285 be_left,
3286 self.render_literal_value(value, parameter.type),
3287 be_right,
3288 )
3289 for value in values
3290 )
3291 else:
3292 replacement_expression = ", ".join(
3293 self.render_literal_value(value, parameter.type)
3294 for value in values
3295 )
3296
3297 return (), replacement_expression
3298
3299 def _literal_execute_expanding_parameter(self, name, parameter, values):
3300 if parameter.literal_execute:
3301 return self._literal_execute_expanding_parameter_literal_binds(
3302 parameter, values
3303 )
3304
3305 dialect = self.dialect
3306 typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect)
3307
3308 if self._numeric_binds:
3309 bind_template = self.compilation_bindtemplate
3310 else:
3311 bind_template = self.bindtemplate
3312
3313 if (
3314 self.dialect._bind_typing_render_casts
3315 and typ_dialect_impl.render_bind_cast
3316 ):
3317
3318 def _render_bindtemplate(name):
3319 return self.render_bind_cast(
3320 parameter.type,
3321 typ_dialect_impl,
3322 bind_template % {"name": name},
3323 )
3324
3325 else:
3326
3327 def _render_bindtemplate(name):
3328 return bind_template % {"name": name}
3329
3330 if not values:
3331 to_update = []
3332 if typ_dialect_impl._is_tuple_type:
3333 replacement_expression = self.visit_empty_set_op_expr(
3334 parameter.type.types, parameter.expand_op
3335 )
3336 else:
3337 replacement_expression = self.visit_empty_set_op_expr(
3338 [parameter.type], parameter.expand_op
3339 )
3340
3341 elif typ_dialect_impl._is_tuple_type or (
3342 typ_dialect_impl._isnull
3343 and isinstance(values[0], collections_abc.Sequence)
3344 and not isinstance(values[0], (str, bytes))
3345 ):
3346 assert not typ_dialect_impl._is_array
3347 to_update = [
3348 ("%s_%s_%s" % (name, i, j), value)
3349 for i, tuple_element in enumerate(values, 1)
3350 for j, value in enumerate(tuple_element, 1)
3351 ]
3352
3353 replacement_expression = (
3354 "VALUES " if dialect.tuple_in_values else ""
3355 ) + ", ".join(
3356 "(%s)"
3357 % (
3358 ", ".join(
3359 _render_bindtemplate(
3360 to_update[i * len(tuple_element) + j][0]
3361 )
3362 for j, value in enumerate(tuple_element)
3363 )
3364 )
3365 for i, tuple_element in enumerate(values)
3366 )
3367 else:
3368 to_update = [
3369 ("%s_%s" % (name, i), value)
3370 for i, value in enumerate(values, 1)
3371 ]
3372 replacement_expression = ", ".join(
3373 _render_bindtemplate(key) for key, value in to_update
3374 )
3375
3376 return to_update, replacement_expression
3377
3378 def visit_binary(
3379 self,
3380 binary,
3381 override_operator=None,
3382 eager_grouping=False,
3383 from_linter=None,
3384 lateral_from_linter=None,
3385 **kw,
3386 ):
3387 if from_linter and operators.is_comparison(binary.operator):
3388 if lateral_from_linter is not None:
3389 enclosing_lateral = kw["enclosing_lateral"]
3390 lateral_from_linter.edges.update(
3391 itertools.product(
3392 _de_clone(
3393 binary.left._from_objects + [enclosing_lateral]
3394 ),
3395 _de_clone(
3396 binary.right._from_objects + [enclosing_lateral]
3397 ),
3398 )
3399 )
3400 else:
3401 from_linter.edges.update(
3402 itertools.product(
3403 _de_clone(binary.left._from_objects),
3404 _de_clone(binary.right._from_objects),
3405 )
3406 )
3407
3408 # don't allow "? = ?" to render
3409 if (
3410 self.ansi_bind_rules
3411 and isinstance(binary.left, elements.BindParameter)
3412 and isinstance(binary.right, elements.BindParameter)
3413 ):
3414 kw["literal_execute"] = True
3415
3416 operator_ = override_operator or binary.operator
3417 disp = self._get_operator_dispatch(operator_, "binary", None)
3418 if disp:
3419 return disp(binary, operator_, **kw)
3420 else:
3421 try:
3422 opstring = OPERATORS[operator_]
3423 except KeyError as err:
3424 raise exc.UnsupportedCompilationError(self, operator_) from err
3425 else:
3426 return self._generate_generic_binary(
3427 binary,
3428 opstring,
3429 from_linter=from_linter,
3430 lateral_from_linter=lateral_from_linter,
3431 **kw,
3432 )
3433
3434 def visit_function_as_comparison_op_binary(self, element, operator, **kw):
3435 return self.process(element.sql_function, **kw)
3436
3437 def visit_mod_binary(self, binary, operator, **kw):
3438 if self.preparer._double_percents:
3439 return (
3440 self.process(binary.left, **kw)
3441 + " %% "
3442 + self.process(binary.right, **kw)
3443 )
3444 else:
3445 return (
3446 self.process(binary.left, **kw)
3447 + " % "
3448 + self.process(binary.right, **kw)
3449 )
3450
3451 def visit_custom_op_binary(self, element, operator, **kw):
3452 kw["eager_grouping"] = operator.eager_grouping
3453 return self._generate_generic_binary(
3454 element,
3455 " " + self.escape_literal_column(operator.opstring) + " ",
3456 **kw,
3457 )
3458
3459 def visit_custom_op_unary_operator(self, element, operator, **kw):
3460 return self._generate_generic_unary_operator(
3461 element, self.escape_literal_column(operator.opstring) + " ", **kw
3462 )
3463
3464 def visit_custom_op_unary_modifier(self, element, operator, **kw):
3465 return self._generate_generic_unary_modifier(
3466 element, " " + self.escape_literal_column(operator.opstring), **kw
3467 )
3468
3469 def _generate_generic_binary(
3470 self,
3471 binary: BinaryExpression[Any],
3472 opstring: str,
3473 eager_grouping: bool = False,
3474 **kw: Any,
3475 ) -> str:
3476 _in_operator_expression = kw.get("_in_operator_expression", False)
3477
3478 kw["_in_operator_expression"] = True
3479 kw["_binary_op"] = binary.operator
3480 text = (
3481 binary.left._compiler_dispatch(
3482 self, eager_grouping=eager_grouping, **kw
3483 )
3484 + opstring
3485 + binary.right._compiler_dispatch(
3486 self, eager_grouping=eager_grouping, **kw
3487 )
3488 )
3489
3490 if _in_operator_expression and eager_grouping:
3491 text = "(%s)" % text
3492 return text
3493
3494 def _generate_generic_unary_operator(self, unary, opstring, **kw):
3495 return opstring + unary.element._compiler_dispatch(self, **kw)
3496
3497 def _generate_generic_unary_modifier(self, unary, opstring, **kw):
3498 return unary.element._compiler_dispatch(self, **kw) + opstring
3499
3500 @util.memoized_property
3501 def _like_percent_literal(self):
3502 return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE)
3503
3504 def visit_ilike_case_insensitive_operand(self, element, **kw):
3505 return f"lower({element.element._compiler_dispatch(self, **kw)})"
3506
3507 def visit_contains_op_binary(self, binary, operator, **kw):
3508 binary = binary._clone()
3509 percent = self._like_percent_literal
3510 binary.right = percent.concat(binary.right).concat(percent)
3511 return self.visit_like_op_binary(binary, operator, **kw)
3512
3513 def visit_not_contains_op_binary(self, binary, operator, **kw):
3514 binary = binary._clone()
3515 percent = self._like_percent_literal
3516 binary.right = percent.concat(binary.right).concat(percent)
3517 return self.visit_not_like_op_binary(binary, operator, **kw)
3518
3519 def visit_icontains_op_binary(self, binary, operator, **kw):
3520 binary = binary._clone()
3521 percent = self._like_percent_literal
3522 binary.left = ilike_case_insensitive(binary.left)
3523 binary.right = percent.concat(
3524 ilike_case_insensitive(binary.right)
3525 ).concat(percent)
3526 return self.visit_ilike_op_binary(binary, operator, **kw)
3527
3528 def visit_not_icontains_op_binary(self, binary, operator, **kw):
3529 binary = binary._clone()
3530 percent = self._like_percent_literal
3531 binary.left = ilike_case_insensitive(binary.left)
3532 binary.right = percent.concat(
3533 ilike_case_insensitive(binary.right)
3534 ).concat(percent)
3535 return self.visit_not_ilike_op_binary(binary, operator, **kw)
3536
3537 def visit_startswith_op_binary(self, binary, operator, **kw):
3538 binary = binary._clone()
3539 percent = self._like_percent_literal
3540 binary.right = percent._rconcat(binary.right)
3541 return self.visit_like_op_binary(binary, operator, **kw)
3542
3543 def visit_not_startswith_op_binary(self, binary, operator, **kw):
3544 binary = binary._clone()
3545 percent = self._like_percent_literal
3546 binary.right = percent._rconcat(binary.right)
3547 return self.visit_not_like_op_binary(binary, operator, **kw)
3548
3549 def visit_istartswith_op_binary(self, binary, operator, **kw):
3550 binary = binary._clone()
3551 percent = self._like_percent_literal
3552 binary.left = ilike_case_insensitive(binary.left)
3553 binary.right = percent._rconcat(ilike_case_insensitive(binary.right))
3554 return self.visit_ilike_op_binary(binary, operator, **kw)
3555
3556 def visit_not_istartswith_op_binary(self, binary, operator, **kw):
3557 binary = binary._clone()
3558 percent = self._like_percent_literal
3559 binary.left = ilike_case_insensitive(binary.left)
3560 binary.right = percent._rconcat(ilike_case_insensitive(binary.right))
3561 return self.visit_not_ilike_op_binary(binary, operator, **kw)
3562
3563 def visit_endswith_op_binary(self, binary, operator, **kw):
3564 binary = binary._clone()
3565 percent = self._like_percent_literal
3566 binary.right = percent.concat(binary.right)
3567 return self.visit_like_op_binary(binary, operator, **kw)
3568
3569 def visit_not_endswith_op_binary(self, binary, operator, **kw):
3570 binary = binary._clone()
3571 percent = self._like_percent_literal
3572 binary.right = percent.concat(binary.right)
3573 return self.visit_not_like_op_binary(binary, operator, **kw)
3574
3575 def visit_iendswith_op_binary(self, binary, operator, **kw):
3576 binary = binary._clone()
3577 percent = self._like_percent_literal
3578 binary.left = ilike_case_insensitive(binary.left)
3579 binary.right = percent.concat(ilike_case_insensitive(binary.right))
3580 return self.visit_ilike_op_binary(binary, operator, **kw)
3581
3582 def visit_not_iendswith_op_binary(self, binary, operator, **kw):
3583 binary = binary._clone()
3584 percent = self._like_percent_literal
3585 binary.left = ilike_case_insensitive(binary.left)
3586 binary.right = percent.concat(ilike_case_insensitive(binary.right))
3587 return self.visit_not_ilike_op_binary(binary, operator, **kw)
3588
3589 def visit_like_op_binary(self, binary, operator, **kw):
3590 escape = binary.modifiers.get("escape", None)
3591
3592 return "%s LIKE %s" % (
3593 binary.left._compiler_dispatch(self, **kw),
3594 binary.right._compiler_dispatch(self, **kw),
3595 ) + (
3596 " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
3597 if escape is not None
3598 else ""
3599 )
3600
3601 def visit_not_like_op_binary(self, binary, operator, **kw):
3602 escape = binary.modifiers.get("escape", None)
3603 return "%s NOT LIKE %s" % (
3604 binary.left._compiler_dispatch(self, **kw),
3605 binary.right._compiler_dispatch(self, **kw),
3606 ) + (
3607 " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
3608 if escape is not None
3609 else ""
3610 )
3611
3612 def visit_ilike_op_binary(self, binary, operator, **kw):
3613 if operator is operators.ilike_op:
3614 binary = binary._clone()
3615 binary.left = ilike_case_insensitive(binary.left)
3616 binary.right = ilike_case_insensitive(binary.right)
3617 # else we assume ilower() has been applied
3618
3619 return self.visit_like_op_binary(binary, operator, **kw)
3620
3621 def visit_not_ilike_op_binary(self, binary, operator, **kw):
3622 if operator is operators.not_ilike_op:
3623 binary = binary._clone()
3624 binary.left = ilike_case_insensitive(binary.left)
3625 binary.right = ilike_case_insensitive(binary.right)
3626 # else we assume ilower() has been applied
3627
3628 return self.visit_not_like_op_binary(binary, operator, **kw)
3629
3630 def visit_between_op_binary(self, binary, operator, **kw):
3631 symmetric = binary.modifiers.get("symmetric", False)
3632 return self._generate_generic_binary(
3633 binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw
3634 )
3635
3636 def visit_not_between_op_binary(self, binary, operator, **kw):
3637 symmetric = binary.modifiers.get("symmetric", False)
3638 return self._generate_generic_binary(
3639 binary,
3640 " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ",
3641 **kw,
3642 )
3643
3644 def visit_regexp_match_op_binary(
3645 self, binary: BinaryExpression[Any], operator: Any, **kw: Any
3646 ) -> str:
3647 raise exc.CompileError(
3648 "%s dialect does not support regular expressions"
3649 % self.dialect.name
3650 )
3651
3652 def visit_not_regexp_match_op_binary(
3653 self, binary: BinaryExpression[Any], operator: Any, **kw: Any
3654 ) -> str:
3655 raise exc.CompileError(
3656 "%s dialect does not support regular expressions"
3657 % self.dialect.name
3658 )
3659
3660 def visit_regexp_replace_op_binary(
3661 self, binary: BinaryExpression[Any], operator: Any, **kw: Any
3662 ) -> str:
3663 raise exc.CompileError(
3664 "%s dialect does not support regular expression replacements"
3665 % self.dialect.name
3666 )
3667
    def visit_bindparam(
        self,
        bindparam,
        within_columns_clause=False,
        literal_binds=False,
        skip_bind_expression=False,
        literal_execute=False,
        render_postcompile=False,
        **kwargs,
    ):
        """Render a :class:`.BindParameter`.

        Depending on the options and the parameter, the output is one of:

        * the type's ``bind_expression()`` wrapping the parameter, when the
          dialect-level type defines one and ``skip_bind_expression`` is
          not set;
        * an inline literal, when ``literal_binds`` is set;
        * a ``__[POSTCOMPILE_<name>]`` token, for ``literal_execute`` and
          expanding parameters;
        * otherwise, the placeholder produced by :meth:`bindparam_string`.

        Expanding parameters are returned wrapped in parentheses.  On the
        non-literal path the parameter is registered in ``self.binds``
        under both its key and its (possibly truncated) name.  A name
        conflict with a different parameter raises
        :class:`.CompileError`.
        """

        if not skip_bind_expression:
            impl = bindparam.type.dialect_impl(self.dialect)
            if impl._has_bind_expression:
                # compile the type's bind_expression() wrapper; the nested
                # call passes skip_bind_expression=True so the parameter
                # inside the wrapper is rendered without wrapping again
                bind_expression = impl.bind_expression(bindparam)
                wrapped = self.process(
                    bind_expression,
                    skip_bind_expression=True,
                    within_columns_clause=within_columns_clause,
                    literal_binds=literal_binds and not bindparam.expanding,
                    literal_execute=literal_execute,
                    render_postcompile=render_postcompile,
                    **kwargs,
                )
                if bindparam.expanding:
                    # for postcompile w/ expanding, move the "wrapped" part
                    # of this into the inside
                    # i.e. "<left>(__[POSTCOMPILE_name])<right>" becomes
                    # "(__[POSTCOMPILE_name~~<left>~~REPL~~<right>~~])"

                    m = re.match(
                        r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
                    )
                    assert m, "unexpected format for expanding parameter"
                    wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
                        m.group(2),
                        m.group(1),
                        m.group(3),
                    )

                    if literal_binds:
                        # render the values inline, applying the
                        # left/right wrapper pieces around each value
                        ret = self.render_literal_bindparam(
                            bindparam,
                            within_columns_clause=True,
                            bind_expression_template=wrapped,
                            **kwargs,
                        )
                        return "(%s)" % ret

                return wrapped

        if not literal_binds:
            # a parameter is rendered as a post-compile token when
            # literal_execute is requested by the caller or the parameter,
            # when ANSI bind rules forbid binds in the columns clause, or
            # when it is expanding
            literal_execute = (
                literal_execute
                or bindparam.literal_execute
                or (within_columns_clause and self.ansi_bind_rules)
            )
            post_compile = literal_execute or bindparam.expanding
        else:
            post_compile = False

        if literal_binds:
            ret = self.render_literal_bindparam(
                bindparam, within_columns_clause=True, **kwargs
            )
            if bindparam.expanding:
                ret = "(%s)" % ret
            return ret

        name = self._truncate_bindparam(bindparam)

        if name in self.binds:
            # a different BindParameter already owns this name; decide
            # whether sharing the name is allowed
            existing = self.binds[name]
            if existing is not bindparam:
                if (
                    (existing.unique or bindparam.unique)
                    and not existing.proxy_set.intersection(
                        bindparam.proxy_set
                    )
                    and not existing._cloned_set.intersection(
                        bindparam._cloned_set
                    )
                ):
                    raise exc.CompileError(
                        "Bind parameter '%s' conflicts with "
                        "unique bind parameter of the same name" % name
                    )
                elif existing.expanding != bindparam.expanding:
                    raise exc.CompileError(
                        "Can't reuse bound parameter name '%s' in both "
                        "'expanding' (e.g. within an IN expression) and "
                        "non-expanding contexts. If this parameter is to "
                        "receive a list/array value, set 'expanding=True' on "
                        "it for expressions that aren't IN, otherwise use "
                        "a different parameter name." % (name,)
                    )
                elif existing._is_crud or bindparam._is_crud:
                    if existing._is_crud and bindparam._is_crud:
                        # TODO: this condition is not well understood.
                        # see tests in test/sql/test_update.py
                        raise exc.CompileError(
                            "Encountered unsupported case when compiling an "
                            "INSERT or UPDATE statement. If this is a "
                            "multi-table "
                            "UPDATE statement, please provide string-named "
                            "arguments to the "
                            "values() method with distinct names; support for "
                            "multi-table UPDATE statements that "
                            "target multiple tables for UPDATE is very "
                            "limited",
                        )
                    else:
                        raise exc.CompileError(
                            f"bindparam() name '{bindparam.key}' is reserved "
                            "for automatic usage in the VALUES or SET "
                            "clause of this "
                            "insert/update statement. Please use a "
                            "name other than column name when using "
                            "bindparam() "
                            "with insert() or update() (for example, "
                            f"'b_{bindparam.key}')."
                        )

        # register under both the original key and the rendered name
        self.binds[bindparam.key] = self.binds[name] = bindparam

        # if we are given a cache key that we're going to match against,
        # relate the bindparam here to one that is most likely present
        # in the "extracted params" portion of the cache key. this is used
        # to set up a positional mapping that is used to determine the
        # correct parameters for a subsequent use of this compiled with
        # a different set of parameter values. here, we accommodate for
        # parameters that may have been cloned both before and after the cache
        # key was been generated.
        ckbm_tuple = self._cache_key_bind_match

        if ckbm_tuple:
            ckbm, cksm = ckbm_tuple
            for bp in bindparam._cloned_set:
                if bp.key in cksm:
                    cb = cksm[bp.key]
                    ckbm[cb].append(bindparam)

        if bindparam.isoutparam:
            self.has_out_parameters = True

        if post_compile:
            if render_postcompile:
                self._render_postcompile = True

            # track the parameter so its post-compile token can be
            # processed later
            if literal_execute:
                self.literal_execute_params |= {bindparam}
            else:
                self.post_compile_params |= {bindparam}

        ret = self.bindparam_string(
            name,
            post_compile=post_compile,
            expanding=bindparam.expanding,
            bindparam_type=bindparam.type,
            **kwargs,
        )

        if bindparam.expanding:
            ret = "(%s)" % ret

        return ret
3832
3833 def render_bind_cast(self, type_, dbapi_type, sqltext):
3834 raise NotImplementedError()
3835
3836 def render_literal_bindparam(
3837 self,
3838 bindparam,
3839 render_literal_value=NO_ARG,
3840 bind_expression_template=None,
3841 **kw,
3842 ):
3843 if render_literal_value is not NO_ARG:
3844 value = render_literal_value
3845 else:
3846 if bindparam.value is None and bindparam.callable is None:
3847 op = kw.get("_binary_op", None)
3848 if op and op not in (operators.is_, operators.is_not):
3849 util.warn_limited(
3850 "Bound parameter '%s' rendering literal NULL in a SQL "
3851 "expression; comparisons to NULL should not use "
3852 "operators outside of 'is' or 'is not'",
3853 (bindparam.key,),
3854 )
3855 return self.process(sqltypes.NULLTYPE, **kw)
3856 value = bindparam.effective_value
3857
3858 if bindparam.expanding:
3859 leep = self._literal_execute_expanding_parameter_literal_binds
3860 to_update, replacement_expr = leep(
3861 bindparam,
3862 value,
3863 bind_expression_template=bind_expression_template,
3864 )
3865 return replacement_expr
3866 else:
3867 return self.render_literal_value(value, bindparam.type)
3868
3869 def render_literal_value(
3870 self, value: Any, type_: sqltypes.TypeEngine[Any]
3871 ) -> str:
3872 """Render the value of a bind parameter as a quoted literal.
3873
3874 This is used for statement sections that do not accept bind parameters
3875 on the target driver/database.
3876
3877 This should be implemented by subclasses using the quoting services
3878 of the DBAPI.
3879
3880 """
3881
3882 if value is None and not type_.should_evaluate_none:
3883 # issue #10535 - handle NULL in the compiler without placing
3884 # this onto each type, except for "evaluate None" types
3885 # (e.g. JSON)
3886 return self.process(elements.Null._instance())
3887
3888 processor = type_._cached_literal_processor(self.dialect)
3889 if processor:
3890 try:
3891 return processor(value)
3892 except Exception as e:
3893 raise exc.CompileError(
3894 f"Could not render literal value "
3895 f'"{sql_util._repr_single_value(value)}" '
3896 f"with datatype "
3897 f"{type_}; see parent stack trace for "
3898 "more detail."
3899 ) from e
3900
3901 else:
3902 raise exc.CompileError(
3903 f"No literal value renderer is available for literal value "
3904 f'"{sql_util._repr_single_value(value)}" '
3905 f"with datatype {type_}"
3906 )
3907
3908 def _truncate_bindparam(self, bindparam):
3909 if bindparam in self.bind_names:
3910 return self.bind_names[bindparam]
3911
3912 bind_name = bindparam.key
3913 if isinstance(bind_name, elements._truncated_label):
3914 bind_name = self._truncated_identifier("bindparam", bind_name)
3915
3916 # add to bind_names for translation
3917 self.bind_names[bindparam] = bind_name
3918
3919 return bind_name
3920
3921 def _truncated_identifier(
3922 self, ident_class: str, name: _truncated_label
3923 ) -> str:
3924 if (ident_class, name) in self.truncated_names:
3925 return self.truncated_names[(ident_class, name)]
3926
3927 anonname = name.apply_map(self.anon_map)
3928
3929 if len(anonname) > self.label_length - 6:
3930 counter = self._truncated_counters.get(ident_class, 1)
3931 truncname = (
3932 anonname[0 : max(self.label_length - 6, 0)]
3933 + "_"
3934 + hex(counter)[2:]
3935 )
3936 self._truncated_counters[ident_class] = counter + 1
3937 else:
3938 truncname = anonname
3939 self.truncated_names[(ident_class, name)] = truncname
3940 return truncname
3941
3942 def _anonymize(self, name: str) -> str:
3943 return name % self.anon_map
3944
3945 def bindparam_string(
3946 self,
3947 name: str,
3948 post_compile: bool = False,
3949 expanding: bool = False,
3950 escaped_from: Optional[str] = None,
3951 bindparam_type: Optional[TypeEngine[Any]] = None,
3952 accumulate_bind_names: Optional[Set[str]] = None,
3953 visited_bindparam: Optional[List[str]] = None,
3954 **kw: Any,
3955 ) -> str:
3956 # TODO: accumulate_bind_names is passed by crud.py to gather
3957 # names on a per-value basis, visited_bindparam is passed by
3958 # visit_insert() to collect all parameters in the statement.
3959 # see if this gathering can be simplified somehow
3960 if accumulate_bind_names is not None:
3961 accumulate_bind_names.add(name)
3962 if visited_bindparam is not None:
3963 visited_bindparam.append(name)
3964
3965 if not escaped_from:
3966 if self._bind_translate_re.search(name):
3967 # not quite the translate use case as we want to
3968 # also get a quick boolean if we even found
3969 # unusual characters in the name
3970 new_name = self._bind_translate_re.sub(
3971 lambda m: self._bind_translate_chars[m.group(0)],
3972 name,
3973 )
3974 escaped_from = name
3975 name = new_name
3976
3977 if escaped_from:
3978 self.escaped_bind_names = self.escaped_bind_names.union(
3979 {escaped_from: name}
3980 )
3981 if post_compile:
3982 ret = "__[POSTCOMPILE_%s]" % name
3983 if expanding:
3984 # for expanding, bound parameters or literal values will be
3985 # rendered per item
3986 return ret
3987
3988 # otherwise, for non-expanding "literal execute", apply
3989 # bind casts as determined by the datatype
3990 if bindparam_type is not None:
3991 type_impl = bindparam_type._unwrapped_dialect_impl(
3992 self.dialect
3993 )
3994 if type_impl.render_literal_cast:
3995 ret = self.render_bind_cast(bindparam_type, type_impl, ret)
3996 return ret
3997 elif self.state is CompilerState.COMPILING:
3998 ret = self.compilation_bindtemplate % {"name": name}
3999 else:
4000 ret = self.bindtemplate % {"name": name}
4001
4002 if (
4003 bindparam_type is not None
4004 and self.dialect._bind_typing_render_casts
4005 ):
4006 type_impl = bindparam_type._unwrapped_dialect_impl(self.dialect)
4007 if type_impl.render_bind_cast:
4008 ret = self.render_bind_cast(bindparam_type, type_impl, ret)
4009
4010 return ret
4011
4012 def _dispatch_independent_ctes(self, stmt, kw):
4013 local_kw = kw.copy()
4014 local_kw.pop("cte_opts", None)
4015 for cte, opt in zip(
4016 stmt._independent_ctes, stmt._independent_ctes_opts
4017 ):
4018 cte._compiler_dispatch(self, cte_opts=opt, **local_kw)
4019
    def visit_cte(
        self,
        cte: CTE,
        asfrom: bool = False,
        ashint: bool = False,
        fromhints: Optional[_FromHintsType] = None,
        visiting_cte: Optional[CTE] = None,
        from_linter: Optional[FromLinter] = None,
        cte_opts: selectable._CTEOpts = selectable._CTEOpts(False),
        **kwargs: Any,
    ) -> Optional[str]:
        """Compile a :class:`.CTE`, registering its definition text.

        The ``<name>[(cols)] AS [prefixes]\\n(<inner>)[ suffixes]`` text of a
        newly encountered CTE is stored in ``self.ctes`` keyed by the CTE.
        Presumably this is assembled into the WITH clause elsewhere; that
        is not visible here.

        CTEs are tracked by ``(level, name)``.  The level is
        ``len(self.stack)`` for nesting CTEs and ``1`` otherwise.  Two
        mappings are kept: ``ctes_by_level_name`` maps ``(level, name)`` to
        the CTE, and ``level_name_by_cte`` maps the reference CTE to
        ``(level, name, cte_opts)``.

        :return: the CTE's (possibly aliased) name when ``asfrom`` is set;
         the directly compiled inner element when there is no enclosing
         statement on the stack; otherwise ``None``.
        :raises CompileError: when a CTE is marked ``nest_here`` in more
         than one location, or when unrelated CTEs share a name.
        """
        self_ctes = self._init_cte_state()
        assert self_ctes is self.ctes

        kwargs["visiting_cte"] = cte

        cte_name = cte.name

        if isinstance(cte_name, elements._truncated_label):
            cte_name = self._truncated_identifier("alias", cte_name)

        is_new_cte = True
        embedded_in_current_named_cte = False

        _reference_cte = cte._get_reference_cte()

        nesting = cte.nesting or cte_opts.nesting

        # check for CTE already encountered
        if _reference_cte in self.level_name_by_cte:
            cte_level, _, existing_cte_opts = self.level_name_by_cte[
                _reference_cte
            ]
            assert _ == cte_name

            cte_level_name = (cte_level, cte_name)
            existing_cte = self.ctes_by_level_name[cte_level_name]

            # check if we are receiving it here with a specific
            # "nest_here" location; if so, move it to this location

            if cte_opts.nesting:
                if existing_cte_opts.nesting:
                    raise exc.CompileError(
                        "CTE is stated as 'nest_here' in "
                        "more than one location"
                    )

                # re-key the already-registered CTE under the current level
                old_level_name = (cte_level, cte_name)
                cte_level = len(self.stack) if nesting else 1
                cte_level_name = new_level_name = (cte_level, cte_name)

                del self.ctes_by_level_name[old_level_name]
                self.ctes_by_level_name[new_level_name] = existing_cte
                self.level_name_by_cte[_reference_cte] = new_level_name + (
                    cte_opts,
                )

        else:
            cte_level = len(self.stack) if nesting else 1
            cte_level_name = (cte_level, cte_name)

            if cte_level_name in self.ctes_by_level_name:
                existing_cte = self.ctes_by_level_name[cte_level_name]
            else:
                existing_cte = None

        if existing_cte is not None:
            embedded_in_current_named_cte = visiting_cte is existing_cte

            # we've generated a same-named CTE that we are enclosed in,
            # or this is the same CTE.  just return the name.
            if cte is existing_cte._restates or cte is existing_cte:
                is_new_cte = False
            elif existing_cte is cte._restates:
                # we've generated a same-named CTE that is
                # enclosed in us - we take precedence, so
                # discard the text for the "inner".
                del self_ctes[existing_cte]

                existing_cte_reference_cte = existing_cte._get_reference_cte()

                assert existing_cte_reference_cte is _reference_cte
                assert existing_cte_reference_cte is existing_cte

                del self.level_name_by_cte[existing_cte_reference_cte]
            else:
                if (
                    # if the two CTEs have the same hash, which we expect
                    # here means that one/both is an annotated of the other
                    (hash(cte) == hash(existing_cte))
                    # or...
                    or (
                        (
                            # if they are clones, i.e. they came from the ORM
                            # or some other visit method
                            cte._is_clone_of is not None
                            or existing_cte._is_clone_of is not None
                        )
                        # and are deep-copy identical
                        and cte.compare(existing_cte)
                    )
                ):
                    # then consider these two CTEs the same
                    is_new_cte = False
                else:
                    # otherwise these are two CTEs that either will render
                    # differently, or were indicated separately by the user,
                    # with the same name
                    raise exc.CompileError(
                        "Multiple, unrelated CTEs found with "
                        "the same name: %r" % cte_name
                    )

        # an already-registered CTE referenced outside a FROM needs no text
        if not asfrom and not is_new_cte:
            return None

        if cte._cte_alias is not None:
            # this CTE is an alias of another CTE; the aliased one is
            # compiled (if needed) and referenced by its own name below
            pre_alias_cte = cte._cte_alias
            cte_pre_alias_name = cte._cte_alias.name
            if isinstance(cte_pre_alias_name, elements._truncated_label):
                cte_pre_alias_name = self._truncated_identifier(
                    "alias", cte_pre_alias_name
                )
        else:
            pre_alias_cte = cte
            cte_pre_alias_name = None

        if is_new_cte:
            self.ctes_by_level_name[cte_level_name] = cte
            self.level_name_by_cte[_reference_cte] = cte_level_name + (
                cte_opts,
            )

            if pre_alias_cte not in self.ctes:
                self.visit_cte(pre_alias_cte, **kwargs)

            if not cte_pre_alias_name and cte not in self_ctes:
                if cte.recursive:
                    self.ctes_recursive = True
                text = self.preparer.format_alias(cte, cte_name)
                if cte.recursive or cte.element.name_cte_columns:
                    # render an explicit "(col, col, ...)" list after the
                    # CTE name, skipping repeated columns
                    col_source = cte.element

                    # TODO: can we get at the .columns_plus_names collection
                    # that is already (or will be?) generated for the SELECT
                    # rather than calling twice?
                    recur_cols = [
                        # TODO: proxy_name is not technically safe,
                        # see test_cte->
                        # test_with_recursive_no_name_currently_buggy.  not
                        # clear what should be done with such a case
                        fallback_label_name or proxy_name
                        for (
                            _,
                            proxy_name,
                            fallback_label_name,
                            c,
                            repeated,
                        ) in (col_source._generate_columns_plus_names(True))
                        if not repeated
                    ]

                    text += "(%s)" % (
                        ", ".join(
                            self.preparer.format_label_name(
                                ident, anon_map=self.anon_map
                            )
                            for ident in recur_cols
                        )
                    )

                assert kwargs.get("subquery", False) is False

                if not self.stack:
                    # toplevel, this is a stringify of the
                    # cte directly.  just compile the inner
                    # the way alias() does.
                    return cte.element._compiler_dispatch(
                        self, asfrom=asfrom, **kwargs
                    )
                else:
                    prefixes = self._generate_prefixes(
                        cte, cte._prefixes, **kwargs
                    )
                    inner = cte.element._compiler_dispatch(
                        self, asfrom=True, **kwargs
                    )

                    text += " AS %s\n(%s)" % (prefixes, inner)

                if cte._suffixes:
                    text += " " + self._generate_prefixes(
                        cte, cte._suffixes, **kwargs
                    )

                self_ctes[cte] = text

        if asfrom:
            if from_linter:
                from_linter.froms[cte._de_clone()] = cte_name

            if not is_new_cte and embedded_in_current_named_cte:
                return self.preparer.format_alias(cte, cte_name)

            if cte_pre_alias_name:
                # "<aliased cte name> <alias suffix><this cte's name>"
                text = self.preparer.format_alias(cte, cte_pre_alias_name)
                if self.preparer._requires_quotes(cte_name):
                    cte_name = self.preparer.quote(cte_name)
                text += self.get_render_as_alias_suffix(cte_name)
                return text  # type: ignore[no-any-return]
            else:
                return self.preparer.format_alias(cte, cte_name)

        return None
4235
def visit_table_valued_alias(self, element, **kw):
    """Render a table-valued alias in a FROM clause.

    When ``element.joins_implicitly`` is set, ``from_linter=None`` is
    passed downstream so the FROM linter does not track this element.
    Lateral elements are rendered through :meth:`visit_lateral`, all
    others through :meth:`visit_alias`.
    """
    if element.joins_implicitly:
        # disable FROM linting for elements that join implicitly
        kw["from_linter"] = None
    render = self.visit_lateral if element._is_lateral else self.visit_alias
    return render(element, **kw)
4243
def visit_table_valued_column(self, element, **kw):
    """Render a column of a table-valued expression like a plain column."""
    render_column = self.visit_column
    return render_column(element, **kw)
4246
def visit_alias(
    self,
    alias,
    asfrom=False,
    ashint=False,
    iscrud=False,
    fromhints=None,
    subquery=False,
    lateral=False,
    enclosing_alias=None,
    from_linter=None,
    **kwargs,
):
    """Render an alias-like construct (alias, subquery, lateral, ...).

    :param asfrom: render as a FROM element, ``<inner> AS <name>``,
     optionally followed by a derived-column list and a FROM hint.
    :param ashint: render only the formatted alias name, for use inside
     hint text.
    :param subquery: parenthesize the inner element when it is rendered
     as a FROM element.
    :param lateral: rendering was requested through :meth:`visit_lateral`.
    :param enclosing_alias: alias being rendered one level up; when its
     ``.element`` is this alias, the inner element is rendered directly,
     without an additional ``AS`` clause.
    :param from_linter: FROM linter that records this alias' name when it
     is rendered as a FROM element.

    When neither ``asfrom`` nor ``ashint`` is set, only the inner element
    is rendered and the ``subquery`` flag is not passed on.
    """
    if lateral:
        if "enclosing_lateral" not in kwargs:
            # called directly from visit_lateral(); this alias becomes
            # the enclosing lateral for everything rendered beneath it
            assert alias._is_lateral
            kwargs["enclosing_lateral"] = alias

        # lateral objects also track a second from_linter that belongs
        # to the level above us
        track_lateral_linter = (
            from_linter
            and "lateral_from_linter" not in kwargs
            and "enclosing_lateral" in kwargs
        )
        if track_lateral_linter:
            kwargs["lateral_from_linter"] = from_linter

    renders_enclosing = (
        enclosing_alias is not None and enclosing_alias.element is alias
    )
    if renders_enclosing:
        rendered = alias.element._compiler_dispatch(
            self,
            asfrom=asfrom,
            ashint=ashint,
            iscrud=iscrud,
            fromhints=fromhints,
            lateral=lateral,
            enclosing_alias=alias,
            **kwargs,
        )
        if subquery and (asfrom or lateral):
            rendered = "(%s)" % (rendered,)
        return rendered

    kwargs["enclosing_alias"] = alias

    if not (asfrom or ashint):
        # note we cancel the "subquery" flag here as well
        return alias.element._compiler_dispatch(
            self, lateral=lateral, **kwargs
        )

    if isinstance(alias.name, elements._truncated_label):
        alias_name = self._truncated_identifier("alias", alias.name)
    else:
        alias_name = alias.name

    if ashint:
        return self.preparer.format_alias(alias, alias_name)

    # asfrom: "<inner> AS <name>[(<derived columns>)]" plus optional hint
    if from_linter:
        from_linter.froms[alias._de_clone()] = alias_name

    rendered = alias.element._compiler_dispatch(
        self, asfrom=True, lateral=lateral, **kwargs
    )
    if subquery:
        rendered = "(%s)" % (rendered,)

    rendered += self.get_render_as_alias_suffix(
        self.preparer.format_alias(alias, alias_name)
    )

    if alias._supports_derived_columns and alias._render_derived:

        def derived_column(col):
            # "<quoted name>" or "<quoted name> <type>"
            quoted = self.preparer.quote(col.name)
            type_suffix = ""
            if alias._render_derived_w_types:
                type_suffix = " %s" % self.dialect.type_compiler_instance.process(
                    col.type, **kwargs
                )
            return "%s%s" % (quoted, type_suffix)

        rendered += "(%s)" % (
            ", ".join(derived_column(col) for col in alias.c),
        )

    if fromhints and alias in fromhints:
        rendered = self.format_from_hint_text(
            rendered, alias, fromhints[alias], iscrud
        )

    return rendered
4346
4347 def visit_subquery(self, subquery, **kw):
4348 kw["subquery"] = True
4349 return self.visit_alias(subquery, **kw)
4350
def visit_lateral(self, lateral_, **kw):
    """Render a LATERAL alias: ``LATERAL`` followed by the alias text."""
    aliased = self.visit_alias(lateral_, **dict(kw, lateral=True))
    return "LATERAL %s" % aliased
4354
def visit_tablesample(self, tablesample, asfrom=False, **kw):
    """Render ``<alias> TABLESAMPLE <method>[ REPEATABLE (<seed>)]``.

    The sampled selectable is always rendered as a FROM element; the
    incoming ``asfrom`` flag is accepted but not used.
    """
    aliased = self.visit_alias(tablesample, asfrom=True, **kw)
    method = tablesample._get_method()._compiler_dispatch(self, **kw)
    pieces = ["%s TABLESAMPLE %s" % (aliased, method)]

    seed = tablesample.seed
    if seed is not None:
        pieces.append(" REPEATABLE (%s)" % (seed._compiler_dispatch(self, **kw)))

    return "".join(pieces)
4367
4368 def _render_values(self, element, **kw):
4369 kw.setdefault("literal_binds", element.literal_binds)
4370 tuples = ", ".join(
4371 self.process(
4372 elements.Tuple(
4373 types=element._column_types, *elem
4374 ).self_group(),
4375 **kw,
4376 )
4377 for chunk in element._data
4378 for elem in chunk
4379 )
4380 return f"VALUES {tuples}"
4381
def visit_values(
    self, element, asfrom=False, from_linter=None, visiting_cte=None, **kw
):
    """Render a :class:`.Values` construct.

    Outside of a FROM clause only the bare ``VALUES ...`` text is
    returned.  As a FROM element the text is parenthesized and prefixed
    with ``LATERAL`` for lateral elements.  When it is named and is not
    the body of the CTE currently being visited, it is also followed by
    ``AS <name> (<columns>)``.  When it is the CTE body, the bare text is
    returned.

    :raises CompileError: a LATERAL VALUES is the body of a CTE.
    """
    if element._independent_ctes:
        self._dispatch_independent_ctes(element, kw)

    rendered = self._render_values(element, **kw)

    if element._unnamed:
        name = None
    elif isinstance(element.name, elements._truncated_label):
        name = self._truncated_identifier("values", element.name)
    else:
        name = element.name

    lateral = "LATERAL " if element._is_lateral else ""

    if not asfrom:
        return rendered

    if from_linter:
        from_linter.froms[element._de_clone()] = (
            "(unnamed VALUES element)" if name is None else name
        )

    is_cte_body = visiting_cte is not None and visiting_cte.element is element
    if is_cte_body:
        if element._is_lateral:
            raise exc.CompileError(
                "Can't use a LATERAL VALUES expression inside of a CTE"
            )
        return rendered

    if not name:
        return "%s(%s)" % (lateral, rendered)

    # column names are rendered without a table qualifier
    kw["include_table"] = False
    alias_suffix = self.get_render_as_alias_suffix(self.preparer.quote(name))
    column_list = ", ".join(
        col._compiler_dispatch(self, **kw) for col in element.columns
    )
    return "%s(%s)%s (%s)" % (lateral, rendered, alias_suffix, column_list)
4430
def visit_scalar_values(self, element, **kw):
    """Render a scalar VALUES construct as a parenthesized ``(VALUES ...)``."""
    return "(%s)" % (self._render_values(element, **kw),)
4433
def get_render_as_alias_suffix(self, alias_name_text):
    """Return the text that follows an aliased element, e.g. ``" AS x"``.

    Kept as a separate method so that subclasses can change the alias
    syntax.
    """
    keyword = " AS "
    return keyword + alias_name_text
4436
def _add_to_result_map(
    self,
    keyname: str,
    name: str,
    objects: Tuple[Any, ...],
    type_: TypeEngine[Any],
) -> None:
    """Append one :class:`.ResultColumnsEntry` to ``self._result_columns``.

    :param keyname: lookup key of the column.  ``None`` or ``"*"`` sets
     ``self._ordered_columns`` to False and ``self._ad_hoc_textual`` to
     True.
    :param name: name the column is rendered with.
    :param objects: objects that target this column; must be non-empty.
    :param type_: SQL type of the column.
    :raises CompileError: ``type_`` is a tuple type.  This check applies
     to every column, not only to ad-hoc textual ones.
    """
    # note objects must be non-empty for cursor.py to handle the
    # collection properly
    assert objects

    if keyname is None or keyname == "*":
        self._ordered_columns = False
        self._ad_hoc_textual = True

    # checked for every column, regardless of keyname
    if type_._is_tuple_type:
        raise exc.CompileError(
            "Most backends don't support SELECTing "
            "from a tuple() object. If this is an ORM query, "
            "consider using the Bundle object."
        )

    self._result_columns.append(
        ResultColumnsEntry(keyname, name, objects, type_)
    )
4461
4462 def _label_returning_column(
4463 self, stmt, column, populate_result_map, column_clause_args=None, **kw
4464 ):
4465 """Render a column with necessary labels inside of a RETURNING clause.
4466
4467 This method is provided for individual dialects in place of calling
4468 the _label_select_column method directly, so that the two use cases
4469 of RETURNING vs. SELECT can be disambiguated going forward.
4470
4471 .. versionadded:: 1.4.21
4472
4473 """
4474 return self._label_select_column(
4475 None,
4476 column,
4477 populate_result_map,
4478 False,
4479 {} if column_clause_args is None else column_clause_args,
4480 **kw,
4481 )
4482
def _label_select_column(
    self,
    select,
    column,
    populate_result_map,
    asfrom,
    column_clause_args,
    name=None,
    proxy_name=None,
    fallback_label_name=None,
    within_columns_clause=True,
    column_is_repeated=False,
    need_column_expressions=False,
    include_table=True,
):
    """produce labeled columns present in a select().

    Renders ``column`` for the columns clause of a SELECT, or for a
    RETURNING clause via :meth:`_label_returning_column`.  The column is
    wrapped in a :class:`._CompileLabel` when a label is required, and
    the result is compiled with ``column_clause_args``.

    :param select: the enclosing SELECT; ``None`` when called for
     RETURNING.
    :param populate_result_map: when true, an ``add_to_result_map``
     callable is passed down so the compiled column registers itself in
     the result map.
    :param asfrom: the enclosing SELECT is rendered as a subquery.  This
     changes whether table-bound columns and unary expressions get a
     label.
    :param column_clause_args: compile keyword arguments for the column.
     This dict is updated in place before dispatch.
    :param name: explicit label name chosen by the caller's
     columns-plus-names computation.
    :param proxy_name: alternate name for result targeting.  Required
     when ``name`` is given.
    :param fallback_label_name: label used when the heuristics below
     decide a label is needed.  Derived from
     ``column._anon_name_label`` when not given.
    :param column_is_repeated: the column repeats an earlier one.  Its
     result-map entry then targets only its key name.
    :param need_column_expressions: apply the type's
     ``column_expression()`` even when the result map is not populated.
    :param include_table: forwarded to the column's compilation.
    :return: the compiled column text.
    """
    impl = column.type.dialect_impl(self.dialect)

    # types may wrap the column in a SQL expression (column_expression)
    if impl._has_column_expression and (
        need_column_expressions or populate_result_map
    ):
        col_expr = impl.column_expression(column)
    else:
        col_expr = column

    if populate_result_map:
        # pass an "add_to_result_map" callable into the compilation
        # of embedded columns.  this collects information about the
        # column as it will be fetched in the result and is coordinated
        # with cursor.description when the query is executed.
        add_to_result_map = self._add_to_result_map

        # if the SELECT statement told us this column is a repeat,
        # wrap the callable with one that prevents the addition of the
        # targets
        if column_is_repeated:
            _add_to_result_map = add_to_result_map

            def add_to_result_map(keyname, name, objects, type_):
                _add_to_result_map(keyname, name, (keyname,), type_)

        # if we redefined col_expr for type expressions, wrap the
        # callable with one that adds the original column to the targets
        elif col_expr is not column:
            _add_to_result_map = add_to_result_map

            def add_to_result_map(keyname, name, objects, type_):
                _add_to_result_map(
                    keyname, name, (column,) + objects, type_
                )

    else:
        add_to_result_map = None

    # this method is used by some of the dialects for RETURNING,
    # which has different inputs.  _label_returning_column was added
    # as the better target for this now however for 1.4 we will keep
    # _label_select_column directly compatible with this use case.
    # these assertions right now set up the current expected inputs
    assert within_columns_clause, (
        "_label_select_column is only relevant within "
        "the columns clause of a SELECT or RETURNING"
    )
    if isinstance(column, elements.Label):
        # already labeled; re-label only if a type expression wrapped it
        if col_expr is not column:
            result_expr = _CompileLabel(
                col_expr, column.name, alt_names=(column.element,)
            )
        else:
            result_expr = col_expr

    elif name:
        # here, _columns_plus_names has determined there's an explicit
        # label name we need to use.  this is the default for
        # tablenames_plus_columnnames as well as when columns are being
        # deduplicated on name

        assert (
            proxy_name is not None
        ), "proxy_name is required if 'name' is passed"

        result_expr = _CompileLabel(
            col_expr,
            name,
            alt_names=(
                proxy_name,
                # this is a hack to allow legacy result column lookups
                # to work as they did before; this goes away in 2.0.
                # TODO: this only seems to be tested indirectly
                # via test/orm/test_deprecations.py.   should be a
                # resultset test for this
                column._tq_label,
            ),
        )
    else:
        # determine here whether this column should be rendered in
        # a labelled context or not, as we were given no required label
        # name from the caller.  Here we apply heuristics based on the kind
        # of SQL expression involved.

        if col_expr is not column:
            # type-specific expression wrapping the given column,
            # so we render a label
            render_with_label = True
        elif isinstance(column, elements.ColumnClause):
            # table-bound column, we render its name as a label if we are
            # inside of a subquery only
            render_with_label = (
                asfrom
                and not column.is_literal
                and column.table is not None
            )
        elif isinstance(column, elements.TextClause):
            render_with_label = False
        elif isinstance(column, elements.UnaryExpression):
            # unary expression.  notes added as of #12681
            #
            # By convention, the visit_unary() method
            # itself does not add an entry to the result map, and relies
            # upon either the inner expression creating a result map
            # entry, or if not, by creating a label here that produces
            # the result map entry.  Where that happens is based on whether
            # or not the element immediately inside the unary is a
            # NamedColumn subclass or not.
            #
            # Now, this also impacts how the SELECT is written; if
            # we decide to generate a label here, we get the usual
            # "~(x+y) AS anon_1" thing in the columns clause.   If we
            # don't, we don't get an AS at all, we get like
            # "~table.column".
            #
            # But here is the important thing as of modernish (like 1.4)
            # versions of SQLAlchemy - **whether or not the AS <label>
            # is present in the statement is not actually important**.
            # We target result columns **positionally** for a fully
            # compiled ``Select()`` object; before 1.4 we needed those
            # labels to match in cursor.description etc etc but now it
            # really doesn't matter.
            # So really, we could set render_with_label True in all cases.
            # Or we could just have visit_unary() populate the result map
            # in all cases.
            #
            # What we're doing here is strictly trying to not rock the
            # boat too much with when we do/don't render "AS label";
            # labels being present helps in the edge cases that we
            # "fall back" to named cursor.description matching, labels
            # not being present for columns keeps us from having awkward
            # phrases like "SELECT DISTINCT table.x AS x".
            render_with_label = (
                (
                    # exception case to detect if we render "not boolean"
                    # as "not <col>" for native boolean or "<col> = 1"
                    # for non-native boolean.   this is controlled by
                    # visit_is_<true|false>_unary_operator
                    column.operator
                    in (operators.is_false, operators.is_true)
                    and not self.dialect.supports_native_boolean
                )
                or column._wraps_unnamed_column()
                or asfrom
            )
        elif (
            # general class of expressions that don't have a SQL-column
            # addressible name.  includes scalar selects, bind parameters,
            # SQL functions, others
            not isinstance(column, elements.NamedColumn)
            # deeper check that indicates there's no natural "name" to
            # this element, which accommodates for custom SQL constructs
            # that might have a ".name" attribute (but aren't SQL
            # functions) but are not implementing this more recently added
            # base class.  in theory the "NamedColumn" check should be
            # enough, however here we seek to maintain legacy behaviors
            # as well.
            and column._non_anon_label is None
        ):
            render_with_label = True
        else:
            render_with_label = False

        if render_with_label:
            if not fallback_label_name:
                # used by the RETURNING case right now.  we generate it
                # here as 3rd party dialects may be referring to
                # _label_select_column method directly instead of the
                # just-added _label_returning_column method
                assert not column_is_repeated
                fallback_label_name = column._anon_name_label

            # ensure the label goes through label-length truncation
            fallback_label_name = (
                elements._truncated_label(fallback_label_name)
                if not isinstance(
                    fallback_label_name, elements._truncated_label
                )
                else fallback_label_name
            )

            result_expr = _CompileLabel(
                col_expr, fallback_label_name, alt_names=(proxy_name,)
            )
        else:
            result_expr = col_expr

    # note: mutates the caller-supplied dict in place
    column_clause_args.update(
        within_columns_clause=within_columns_clause,
        add_to_result_map=add_to_result_map,
        include_table=include_table,
    )
    return result_expr._compiler_dispatch(self, **column_clause_args)
4691
def format_from_hint_text(self, sqltext, table, hint, iscrud):
    """Append the FROM hint for ``table`` to ``sqltext``.

    The hint text is obtained from :meth:`get_from_hint_text`.  When that
    returns an empty value, ``sqltext`` is returned unchanged.
    ``iscrud`` is accepted but not used here.
    """
    hinttext = self.get_from_hint_text(table, hint)
    return sqltext + " " + hinttext if hinttext else sqltext
4697
def get_select_hint_text(self, byfrom):
    """Return hint text to place right after ``SELECT``, or ``None``.

    :param byfrom: mapping of FROM element to its formatted hint text.

    The base compiler renders no SELECT-level hint text.
    """
    return None
4700
def get_from_hint_text(
    self, table: FromClause, text: Optional[str]
) -> Optional[str]:
    """Return hint text to render after ``table`` in the FROM clause.

    The base compiler renders no per-table hints and always returns
    ``None``.
    """
    return None
4705
def get_crud_hint_text(self, table, text):
    """Return hint text for ``table`` in a CRUD statement, or ``None``.

    The base compiler renders no such hints.
    """
    return None
4708
def get_statement_hint_text(self, hint_texts):
    """Combine statement-level hint strings into one text.

    :param hint_texts: the hint strings that apply to this dialect.
    :return: the hints separated by single spaces.
    """
    separator = " "
    return separator.join(hint_texts)
4711
# Stack entry used when nothing has been pushed onto ``self.stack`` yet
# (top-level compilation); both correlation sets start out empty.
_default_stack_entry: _CompilerStackEntry

# the value is assigned only at runtime; type checkers see just the
# annotation above.  NOTE(review): presumably util.immutabledict is used
# so the shared default can't be mutated by callers - confirm.
if not typing.TYPE_CHECKING:
    _default_stack_entry = util.immutabledict(
        [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
    )
4718
4719 def _display_froms_for_select(
4720 self, select_stmt, asfrom, lateral=False, **kw
4721 ):
4722 # utility method to help external dialects
4723 # get the correct from list for a select.
4724 # specifically the oracle dialect needs this feature
4725 # right now.
4726 toplevel = not self.stack
4727 entry = self._default_stack_entry if toplevel else self.stack[-1]
4728
4729 compile_state = select_stmt._compile_state_factory(select_stmt, self)
4730
4731 correlate_froms = entry["correlate_froms"]
4732 asfrom_froms = entry["asfrom_froms"]
4733
4734 if asfrom and not lateral:
4735 froms = compile_state._get_display_froms(
4736 explicit_correlate_froms=correlate_froms.difference(
4737 asfrom_froms
4738 ),
4739 implicit_correlate_froms=(),
4740 )
4741 else:
4742 froms = compile_state._get_display_froms(
4743 explicit_correlate_froms=correlate_froms,
4744 implicit_correlate_froms=asfrom_froms,
4745 )
4746 return froms
4747
# Optional dialect hook consulted by visit_select().  It is called as
# ``translate_select_structure(select_stmt, asfrom=..., **kw)``.  When it
# returns a different object than the statement passed in, visit_select()
# builds a new compile state for the returned SELECT and remaps the
# result columns back to the original statement's columns.
translate_select_structure: Any = None
"""if not ``None``, should be a callable which accepts ``(select_stmt,
**kw)`` and returns a select object.   this is used for structural changes
mostly to accommodate for LIMIT/OFFSET schemes

"""
4754
def visit_select(
    self,
    select_stmt,
    asfrom=False,
    insert_into=False,
    fromhints=None,
    compound_index=None,
    select_wraps_for=None,
    lateral=False,
    from_linter=None,
    **kwargs,
):
    """Render a SELECT statement.

    Steps, in order:

    1. Build the compile state for ``select_stmt``.  This may produce a
       different statement, e.g. when converting an ORM SELECT to Core.
    2. Optionally let :attr:`translate_select_structure` restructure it.
    3. Push a new entry onto ``self.stack``.
    4. Render the column list, populating the result map when required.
    5. Render the body clauses and statement hints.
    6. Render any CTEs collected for this level, then suffixes.
    7. Pop the stack entry.

    :param asfrom: the SELECT is rendered inside a FROM clause.
    :param insert_into: with a non-``None`` ``compound_index``, marks the
     SELECT as embedded.  An embedded, non-toplevel SELECT leaves CTE
     rendering to the enclosing statement.
    :param compound_index: position inside a compound SELECT.  Members
     after the first do not populate the result map.
    :param select_wraps_for: must be ``None``; used internally after
     :attr:`translate_select_structure` restructures the statement.
    :param lateral: the SELECT is rendered as part of a LATERAL alias.
    :param fromhints: accepted but not used here.
    :param from_linter: accepted but not used here; a new FROM linter is
     created while composing the SELECT body when linting is enabled.
    :return: the SELECT statement text.
    """
    assert select_wraps_for is None, (
        "SQLAlchemy 1.4 requires use of "
        "the translate_select_structure hook for structural "
        "translations of SELECT objects"
    )

    # initial setup of SELECT.  the compile_state_factory may now
    # be creating a totally different SELECT from the one that was
    # passed in.  for ORM use this will convert from an ORM-state
    # SELECT to a regular "Core" SELECT.  other composed operations
    # such as computation of joins will be performed.

    kwargs["within_columns_clause"] = False

    compile_state = select_stmt._compile_state_factory(
        select_stmt, self, **kwargs
    )
    kwargs["ambiguous_table_name_map"] = (
        compile_state._ambiguous_table_name_map
    )

    select_stmt = compile_state.statement

    toplevel = not self.stack

    # the first top-level compile state is recorded on the compiler
    if toplevel and not self.compile_state:
        self.compile_state = compile_state

    is_embedded_select = compound_index is not None or insert_into

    # translate step for Oracle, SQL Server which often need to
    # restructure the SELECT to allow for LIMIT/OFFSET and possibly
    # other conditions
    if self.translate_select_structure:
        new_select_stmt = self.translate_select_structure(
            select_stmt, asfrom=asfrom, **kwargs
        )

        # if SELECT was restructured, maintain a link to the originals
        # and assemble a new compile state
        if new_select_stmt is not select_stmt:
            compile_state_wraps_for = compile_state
            select_wraps_for = select_stmt
            select_stmt = new_select_stmt

            compile_state = select_stmt._compile_state_factory(
                select_stmt, self, **kwargs
            )
            select_stmt = compile_state.statement

    entry = self._default_stack_entry if toplevel else self.stack[-1]

    # the result map (and type-level column expressions) is needed at
    # top level, or when the enclosing level asks for it
    populate_result_map = need_column_expressions = (
        toplevel
        or entry.get("need_result_map_for_compound", False)
        or entry.get("need_result_map_for_nested", False)
    )

    # indicates there is a CompoundSelect in play and we are not the
    # first select
    if compound_index:
        populate_result_map = False

    # this was first proposed as part of #3372; however, it is not
    # reached in current tests and could possibly be an assertion
    # instead.
    if not populate_result_map and "add_to_result_map" in kwargs:
        del kwargs["add_to_result_map"]

    # computes the FROM list and pushes a new entry onto self.stack;
    # the entry is popped at the end of this method
    froms = self._setup_select_stack(
        select_stmt, compile_state, entry, asfrom, lateral, compound_index
    )

    column_clause_args = kwargs.copy()
    column_clause_args.update(
        {"within_label_clause": False, "within_columns_clause": False}
    )

    text = "SELECT "  # we're off to a good start !

    if select_stmt._hints:
        hint_text, byfrom = self._setup_select_hints(select_stmt)
        if hint_text:
            text += hint_text + " "
    else:
        byfrom = None

    if select_stmt._independent_ctes:
        self._dispatch_independent_ctes(select_stmt, kwargs)

    if select_stmt._prefixes:
        text += self._generate_prefixes(
            select_stmt, select_stmt._prefixes, **kwargs
        )

    text += self.get_select_precolumns(select_stmt, **kwargs)
    # the actual list of columns to print in the SELECT column list.
    inner_columns = [
        c
        for c in [
            self._label_select_column(
                select_stmt,
                column,
                populate_result_map,
                asfrom,
                column_clause_args,
                name=name,
                proxy_name=proxy_name,
                fallback_label_name=fallback_label_name,
                column_is_repeated=repeated,
                need_column_expressions=need_column_expressions,
            )
            for (
                name,
                proxy_name,
                fallback_label_name,
                column,
                repeated,
            ) in compile_state.columns_plus_names
        ]
        if c is not None
    ]

    if populate_result_map and select_wraps_for is not None:
        # if this select was generated from translate_select,
        # rewrite the targeted columns in the result map: each column of
        # the translated SELECT maps to the column at the same position
        # in the original SELECT

        translate = dict(
            zip(
                [
                    name
                    for (
                        key,
                        proxy_name,
                        fallback_label_name,
                        name,
                        repeated,
                    ) in compile_state.columns_plus_names
                ],
                [
                    name
                    for (
                        key,
                        proxy_name,
                        fallback_label_name,
                        name,
                        repeated,
                    ) in compile_state_wraps_for.columns_plus_names
                ],
            )
        )

        self._result_columns = [
            ResultColumnsEntry(
                key, name, tuple(translate.get(o, o) for o in obj), type_
            )
            for key, name, obj, type_ in self._result_columns
        ]

    text = self._compose_select_body(
        text,
        select_stmt,
        compile_state,
        inner_columns,
        froms,
        byfrom,
        toplevel,
        kwargs,
    )

    if select_stmt._statement_hints:
        per_dialect = [
            ht
            for (dialect_name, ht) in select_stmt._statement_hints
            if dialect_name in ("*", self.dialect.name)
        ]
        if per_dialect:
            text += " " + self.get_statement_hint_text(per_dialect)

    # In compound query, CTEs are shared at the compound level
    if self.ctes and (not is_embedded_select or toplevel):
        nesting_level = len(self.stack) if not toplevel else None
        text = self._render_cte_clause(nesting_level=nesting_level) + text

    if select_stmt._suffixes:
        text += " " + self._generate_prefixes(
            select_stmt, select_stmt._suffixes, **kwargs
        )

    # balances the push done in _setup_select_stack()
    self.stack.pop(-1)

    return text
4959
4960 def _setup_select_hints(
4961 self, select: Select[Any]
4962 ) -> Tuple[str, _FromHintsType]:
4963 byfrom = {
4964 from_: hinttext
4965 % {"name": from_._compiler_dispatch(self, ashint=True)}
4966 for (from_, dialect), hinttext in select._hints.items()
4967 if dialect in ("*", self.dialect.name)
4968 }
4969 hint_text = self.get_select_hint_text(byfrom)
4970 return hint_text, byfrom
4971
def _setup_select_stack(
    self, select, compile_state, entry, asfrom, lateral, compound_index
):
    """Compute the FROM list for ``select`` and push its stack entry.

    :param entry: stack entry of the enclosing level (or the default
     entry at top level); supplies the correlation sets.
    :param asfrom: ``select`` is rendered inside a FROM clause.
    :param lateral: ``select`` belongs to a LATERAL alias.
    :param compound_index: position of ``select`` within a compound
     SELECT, or ``None``.  Index 0 is stored in ``entry`` as
     ``"select_0"``; later members are checked against it for an equal
     number of columns.
    :return: the FROM objects to render for ``select``.
    :raises CompileError: a compound member's column count differs from
     the first member's.

    The new entry appended to ``self.stack`` must be popped by the
    caller; visit_select() does so once the SELECT has been rendered.
    """
    correlate_froms = entry["correlate_froms"]
    asfrom_froms = entry["asfrom_froms"]

    if compound_index == 0:
        entry["select_0"] = select
    elif compound_index:
        select_0 = entry["select_0"]
        numcols = len(select_0._all_selected_columns)
        this_numcols = len(compile_state.columns_plus_names)

        if this_numcols != numcols:
            # report the same count that was compared above
            raise exc.CompileError(
                "All selectables passed to "
                "CompoundSelect must have identical numbers of "
                "columns; select #%d has %d columns, select "
                "#%d has %d"
                % (
                    1,
                    numcols,
                    compound_index + 1,
                    this_numcols,
                )
            )

    if asfrom and not lateral:
        froms = compile_state._get_display_froms(
            explicit_correlate_froms=correlate_froms.difference(
                asfrom_froms
            ),
            implicit_correlate_froms=(),
        )
    else:
        froms = compile_state._get_display_froms(
            explicit_correlate_froms=correlate_froms,
            implicit_correlate_froms=asfrom_froms,
        )

    # FROM objects of this SELECT become candidates for correlation by
    # nested SELECTs, together with everything correlatable above us
    new_correlate_froms = set(_from_objects(*froms))
    all_correlate_froms = new_correlate_froms.union(correlate_froms)

    new_entry: _CompilerStackEntry = {
        "asfrom_froms": new_correlate_froms,
        "correlate_froms": all_correlate_froms,
        "selectable": select,
        "compile_state": compile_state,
    }
    self.stack.append(new_entry)

    return froms
5023
def _compose_select_body(
    self,
    text,
    select,
    compile_state,
    inner_columns,
    froms,
    byfrom,
    toplevel,
    kwargs,
):
    """Append the rest of the SELECT to ``text``.

    This covers the column list, FROM, WHERE, GROUP BY, HAVING,
    ORDER BY, row limiting and FOR UPDATE, in that order.

    :param text: the SELECT text built so far.
    :param inner_columns: already-rendered column expressions.
    :param froms: FROM objects to render.  When empty,
     :meth:`default_from` supplies the text instead.
    :param byfrom: per-FROM hint mapping.  Passed as ``fromhints`` only
     when ``select`` has hints.
    :param toplevel: when true and cartesian-product linting is enabled,
     the new linter is stored on ``self.from_linter``.
    :param kwargs: compile keyword arguments for sub-elements.
    :return: the extended text.
    """
    text += ", ".join(inner_columns)

    from_linter = None
    warn_linting = False
    if self.linting & COLLECT_CARTESIAN_PRODUCTS:
        from_linter = FromLinter({}, set())
        warn_linting = self.linting & WARN_LINTING
        if toplevel:
            self.from_linter = from_linter

    # adjust the whitespace for no inner columns, part of #9440, so that
    # a no-col SELECT comes out as "SELECT WHERE..." or "SELECT FROM ...".
    # Building the SELECT string without trailing whitespace in the
    # first place would break custom compilation schemes that are
    # currently being tested.
    if not inner_columns:
        text = text.rstrip()

    if not froms:
        text += self.default_from()
    else:
        # keyword order matches the original per-branch calls
        dispatch_kw = {"asfrom": True}
        if select._hints:
            dispatch_kw["fromhints"] = byfrom
        dispatch_kw["from_linter"] = from_linter
        rendered_froms = [
            f._compiler_dispatch(self, **dispatch_kw, **kwargs)
            for f in froms
        ]
        text += " \nFROM " + ", ".join(rendered_froms)

    if select._where_criteria:
        where_text = self._generate_delimited_and_list(
            select._where_criteria, from_linter=from_linter, **kwargs
        )
        if where_text:
            text += " \nWHERE " + where_text

    # the linter reports only after the WHERE criteria, which receive it
    if warn_linting:
        assert from_linter is not None
        from_linter.warn()

    if select._group_by_clauses:
        text += self.group_by_clause(select, **kwargs)

    if select._having_criteria:
        having_text = self._generate_delimited_and_list(
            select._having_criteria, **kwargs
        )
        if having_text:
            text += " \nHAVING " + having_text

    if select._order_by_clauses:
        text += self.order_by_clause(select, **kwargs)

    if select._has_row_limiting_clause:
        text += self._row_limit_clause(select, **kwargs)

    if select._for_update_arg is not None:
        text += self.for_update_clause(select, **kwargs)

    return text
5118
5119 def _generate_prefixes(self, stmt, prefixes, **kw):
5120 clause = " ".join(
5121 prefix._compiler_dispatch(self, **kw)
5122 for prefix, dialect_name in prefixes
5123 if dialect_name in (None, "*") or dialect_name == self.dialect.name
5124 )
5125 if clause:
5126 clause += " "
5127 return clause
5128
5129 def _render_cte_clause(
5130 self,
5131 nesting_level=None,
5132 include_following_stack=False,
5133 ):
5134 """
5135 include_following_stack
5136 Also render the nesting CTEs on the next stack. Useful for
5137 SQL structures like UNION or INSERT that can wrap SELECT
5138 statements containing nesting CTEs.
5139 """
5140 if not self.ctes:
5141 return ""
5142
5143 ctes: MutableMapping[CTE, str]
5144
5145 if nesting_level and nesting_level > 1:
5146 ctes = util.OrderedDict()
5147 for cte in list(self.ctes.keys()):
5148 cte_level, cte_name, cte_opts = self.level_name_by_cte[
5149 cte._get_reference_cte()
5150 ]
5151 nesting = cte.nesting or cte_opts.nesting
5152 is_rendered_level = cte_level == nesting_level or (
5153 include_following_stack and cte_level == nesting_level + 1
5154 )
5155 if not (nesting and is_rendered_level):
5156 continue
5157
5158 ctes[cte] = self.ctes[cte]
5159
5160 else:
5161 ctes = self.ctes
5162
5163 if not ctes:
5164 return ""
5165 ctes_recursive = any([cte.recursive for cte in ctes])
5166
5167 cte_text = self.get_cte_preamble(ctes_recursive) + " "
5168 cte_text += ", \n".join([txt for txt in ctes.values()])
5169 cte_text += "\n "
5170
5171 if nesting_level and nesting_level > 1:
5172 for cte in list(ctes.keys()):
5173 cte_level, cte_name, cte_opts = self.level_name_by_cte[
5174 cte._get_reference_cte()
5175 ]
5176 del self.ctes[cte]
5177 del self.ctes_by_level_name[(cte_level, cte_name)]
5178 del self.level_name_by_cte[cte._get_reference_cte()]
5179
5180 return cte_text
5181
def get_cte_preamble(self, recursive):
    """Return the keyword(s) that open a CTE clause.

    :param recursive: whether any of the rendered CTEs is recursive.
    """
    return "WITH RECURSIVE" if recursive else "WITH"
5187
def get_select_precolumns(self, select: Select[Any], **kw: Any) -> str:
    """Called when building a ``SELECT`` statement, position is just
    before column list.

    Returns ``"DISTINCT "`` for a distinct SELECT and ``""`` otherwise.
    A DISTINCT ON expression list is not rendered by this base
    implementation; a deprecation warning is emitted for it instead.
    """
    if select._distinct_on:
        util.warn_deprecated(
            "DISTINCT ON is currently supported only by the PostgreSQL "
            "dialect.  Use of DISTINCT ON for other backends is currently "
            "silently ignored, however this usage is deprecated, and will "
            "raise CompileError in a future release for all backends "
            "that do not support this syntax.",
            version="1.4",
        )
    if select._distinct:
        return "DISTINCT "
    return ""
5203
def group_by_clause(self, select, **kw):
    """allow dialects to customize how GROUP BY is rendered.

    Returns ``" GROUP BY "`` followed by the comma-separated GROUP BY
    expressions, or ``""`` when the rendered list is empty.
    """
    group_by = self._generate_delimited_list(
        select._group_by_clauses, OPERATORS[operators.comma_op], **kw
    )
    return " GROUP BY " + group_by if group_by else ""
5214
def order_by_clause(self, select, **kw):
    """allow dialects to customize how ORDER BY is rendered.

    Returns ``" ORDER BY "`` followed by the comma-separated ORDER BY
    expressions, or ``""`` when the rendered list is empty.
    """
    order_by = self._generate_delimited_list(
        select._order_by_clauses, OPERATORS[operators.comma_op], **kw
    )
    return " ORDER BY " + order_by if order_by else ""
5226
def for_update_clause(self, select, **kw):
    """Return the row-locking clause appended to a SELECT.

    The base implementation always renders a plain ``FOR UPDATE``;
    ``select`` is not consulted.
    """
    clause = " FOR UPDATE"
    return clause
5229
def returning_clause(
    self,
    stmt: UpdateBase,
    returning_cols: Sequence[_ColumnsClauseElement],
    *,
    populate_result_map: bool,
    **kw: Any,
) -> str:
    """Render ``RETURNING <col>, ...`` for the DML statement ``stmt``.

    ``returning_cols`` is expanded to individual columns and each one is
    rendered via :meth:`_label_returning_column`, using the names computed
    by ``stmt._generate_columns_plus_names``.
    """
    expanded_cols = base._select_iterables(returning_cols)
    rendered = []
    for name, proxy_name, fallback_label_name, column, repeated in (
        stmt._generate_columns_plus_names(True, cols=expanded_cols)
    ):
        rendered.append(
            self._label_returning_column(
                stmt,
                column,
                populate_result_map,
                fallback_label_name=fallback_label_name,
                column_is_repeated=repeated,
                name=name,
                proxy_name=proxy_name,
                **kw,
            )
        )
    return "RETURNING " + ", ".join(rendered)
5261
5262 def limit_clause(self, select, **kw):
5263 text = ""
5264 if select._limit_clause is not None:
5265 text += "\n LIMIT " + self.process(select._limit_clause, **kw)
5266 if select._offset_clause is not None:
5267 if select._limit_clause is None:
5268 text += "\n LIMIT -1"
5269 text += " OFFSET " + self.process(select._offset_clause, **kw)
5270 return text
5271
5272 def fetch_clause(
5273 self,
5274 select,
5275 fetch_clause=None,
5276 require_offset=False,
5277 use_literal_execute_for_simple_int=False,
5278 **kw,
5279 ):
5280 if fetch_clause is None:
5281 fetch_clause = select._fetch_clause
5282 fetch_clause_options = select._fetch_clause_options
5283 else:
5284 fetch_clause_options = {"percent": False, "with_ties": False}
5285
5286 text = ""
5287
5288 if select._offset_clause is not None:
5289 offset_clause = select._offset_clause
5290 if (
5291 use_literal_execute_for_simple_int
5292 and select._simple_int_clause(offset_clause)
5293 ):
5294 offset_clause = offset_clause.render_literal_execute()
5295 offset_str = self.process(offset_clause, **kw)
5296 text += "\n OFFSET %s ROWS" % offset_str
5297 elif require_offset:
5298 text += "\n OFFSET 0 ROWS"
5299
5300 if fetch_clause is not None:
5301 if (
5302 use_literal_execute_for_simple_int
5303 and select._simple_int_clause(fetch_clause)
5304 ):
5305 fetch_clause = fetch_clause.render_literal_execute()
5306 text += "\n FETCH FIRST %s%s ROWS %s" % (
5307 self.process(fetch_clause, **kw),
5308 " PERCENT" if fetch_clause_options["percent"] else "",
5309 "WITH TIES" if fetch_clause_options["with_ties"] else "ONLY",
5310 )
5311 return text
5312
5313 def visit_table(
5314 self,
5315 table,
5316 asfrom=False,
5317 iscrud=False,
5318 ashint=False,
5319 fromhints=None,
5320 use_schema=True,
5321 from_linter=None,
5322 ambiguous_table_name_map=None,
5323 enclosing_alias=None,
5324 **kwargs,
5325 ):
5326 if from_linter:
5327 from_linter.froms[table] = table.fullname
5328
5329 if asfrom or ashint:
5330 effective_schema = self.preparer.schema_for_object(table)
5331
5332 if use_schema and effective_schema:
5333 ret = (
5334 self.preparer.quote_schema(effective_schema)
5335 + "."
5336 + self.preparer.quote(table.name)
5337 )
5338 else:
5339 ret = self.preparer.quote(table.name)
5340
5341 if (
5342 (
5343 enclosing_alias is None
5344 or enclosing_alias.element is not table
5345 )
5346 and not effective_schema
5347 and ambiguous_table_name_map
5348 and table.name in ambiguous_table_name_map
5349 ):
5350 anon_name = self._truncated_identifier(
5351 "alias", ambiguous_table_name_map[table.name]
5352 )
5353
5354 ret = ret + self.get_render_as_alias_suffix(
5355 self.preparer.format_alias(None, anon_name)
5356 )
5357
5358 if fromhints and table in fromhints:
5359 ret = self.format_from_hint_text(
5360 ret, table, fromhints[table], iscrud
5361 )
5362 return ret
5363 else:
5364 return ""
5365
5366 def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
5367 if from_linter:
5368 from_linter.edges.update(
5369 itertools.product(
5370 _de_clone(join.left._from_objects),
5371 _de_clone(join.right._from_objects),
5372 )
5373 )
5374
5375 if join.full:
5376 join_type = " FULL OUTER JOIN "
5377 elif join.isouter:
5378 join_type = " LEFT OUTER JOIN "
5379 else:
5380 join_type = " JOIN "
5381 return (
5382 join.left._compiler_dispatch(
5383 self, asfrom=True, from_linter=from_linter, **kwargs
5384 )
5385 + join_type
5386 + join.right._compiler_dispatch(
5387 self, asfrom=True, from_linter=from_linter, **kwargs
5388 )
5389 + " ON "
5390 # TODO: likely need asfrom=True here?
5391 + join.onclause._compiler_dispatch(
5392 self, from_linter=from_linter, **kwargs
5393 )
5394 )
5395
5396 def _setup_crud_hints(self, stmt, table_text):
5397 dialect_hints = {
5398 table: hint_text
5399 for (table, dialect), hint_text in stmt._hints.items()
5400 if dialect in ("*", self.dialect.name)
5401 }
5402 if stmt.table in dialect_hints:
5403 table_text = self.format_from_hint_text(
5404 table_text, stmt.table, dialect_hints[stmt.table], True
5405 )
5406 return dialect_hints, table_text
5407
    # within the realm of "insertmanyvalues sentinel columns",
    # these lookups match different kinds of Column() configurations
    # to specific backend capabilities. they are broken into two
    # lookups, one for autoincrement columns and the other for non
    # autoincrement columns
    #
    # Both tables map a _SentinelDefaultCharacterization to the
    # InsertmanyvaluesSentinelOpts bit(s) that
    # _get_sentinel_column_for_table() ANDs against the dialect's
    # insertmanyvalues_implicit_sentinel; a non-zero result means the
    # candidate sentinel column(s) may be used on that dialect.  A
    # characterization missing from a table is looked up with a default
    # of 0 there, so it never matches.
    _sentinel_col_non_autoinc_lookup = util.immutabledict(
        {
            # client-side, sentinel-default and no-default columns all
            # map to the _SUPPORTED_OR_NOT bit set
            _SentinelDefaultCharacterization.CLIENTSIDE: (
                InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT
            ),
            _SentinelDefaultCharacterization.SENTINEL_DEFAULT: (
                InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT
            ),
            _SentinelDefaultCharacterization.NONE: (
                InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT
            ),
            # server-generated values require the matching capability
            _SentinelDefaultCharacterization.IDENTITY: (
                InsertmanyvaluesSentinelOpts.IDENTITY
            ),
            _SentinelDefaultCharacterization.SEQUENCE: (
                InsertmanyvaluesSentinelOpts.SEQUENCE
            ),
        }
    )
    # autoincrement variant: the same table except that a column with no
    # default (NONE) requires the AUTOINCREMENT capability.  Presumably
    # immutabledict.union() returns a new mapping in which these entries
    # override the originals -- confirm against util.immutabledict.
    _sentinel_col_autoinc_lookup = _sentinel_col_non_autoinc_lookup.union(
        {
            _SentinelDefaultCharacterization.NONE: (
                InsertmanyvaluesSentinelOpts.AUTOINCREMENT
            ),
        }
    )
5439
5440 def _get_sentinel_column_for_table(
5441 self, table: Table
5442 ) -> Optional[Sequence[Column[Any]]]:
5443 """given a :class:`.Table`, return a usable sentinel column or
5444 columns for this dialect if any.
5445
5446 Return None if no sentinel columns could be identified, or raise an
5447 error if a column was marked as a sentinel explicitly but isn't
5448 compatible with this dialect.
5449
5450 """
5451
5452 sentinel_opts = self.dialect.insertmanyvalues_implicit_sentinel
5453 sentinel_characteristics = table._sentinel_column_characteristics
5454
5455 sent_cols = sentinel_characteristics.columns
5456
5457 if sent_cols is None:
5458 return None
5459
5460 if sentinel_characteristics.is_autoinc:
5461 bitmask = self._sentinel_col_autoinc_lookup.get(
5462 sentinel_characteristics.default_characterization, 0
5463 )
5464 else:
5465 bitmask = self._sentinel_col_non_autoinc_lookup.get(
5466 sentinel_characteristics.default_characterization, 0
5467 )
5468
5469 if sentinel_opts & bitmask:
5470 return sent_cols
5471
5472 if sentinel_characteristics.is_explicit:
5473 # a column was explicitly marked as insert_sentinel=True,
5474 # however it is not compatible with this dialect. they should
5475 # not indicate this column as a sentinel if they need to include
5476 # this dialect.
5477
5478 # TODO: do we want non-primary key explicit sentinel cols
5479 # that can gracefully degrade for some backends?
5480 # insert_sentinel="degrade" perhaps. not for the initial release.
5481 # I am hoping people are generally not dealing with this sentinel
5482 # business at all.
5483
5484 # if is_explicit is True, there will be only one sentinel column.
5485
5486 raise exc.InvalidRequestError(
5487 f"Column {sent_cols[0]} can't be explicitly "
5488 "marked as a sentinel column when using the "
5489 f"{self.dialect.name} dialect, as the "
5490 "particular type of default generation on this column is "
5491 "not currently compatible with this dialect's specific "
5492 f"INSERT..RETURNING syntax which can receive the "
5493 "server-generated value in "
5494 "a deterministic way. To remove this error, remove "
5495 "insert_sentinel=True from primary key autoincrement "
5496 "columns; these columns are automatically used as "
5497 "sentinels for supported dialects in any case."
5498 )
5499
5500 return None
5501
    def _deliver_insertmanyvalues_batches(
        self,
        statement: str,
        parameters: _DBAPIMultiExecuteParams,
        compiled_parameters: List[_MutableCoreSingleExecuteParams],
        generic_setinputsizes: Optional[_GenericSetInputSizesType],
        batch_size: int,
        sort_by_parameter_order: bool,
        schema_translate_map: Optional[SchemaTranslateMapType],
    ) -> Iterator[_InsertManyValuesBatch]:
        """Yield :class:`._InsertManyValuesBatch` objects that split an
        "insertmanyvalues" executemany into multi-row INSERT statements.

        Requires that ``self._insertmanyvalues`` was populated when the
        INSERT was compiled (asserted below).

        Two strategies are used:

        * row-at-a-time: one batch per parameter set, re-using
          ``statement`` unchanged.  This is chosen in three cases:
          when a default-expression-only VALUES list can't use the DEFAULT
          metavalue (``downgraded=False``); when the dialect doesn't
          support multi-row VALUES; or when ``sort_by_parameter_order``
          is requested with result columns but there are no sentinel
          columns, or the statement has upsert behaviors.  The last two
          cases yield ``downgraded=True``.
        * multi-row: the single VALUES tuple in ``statement`` is replaced
          with up to ``batch_size`` copies, and the parameters are
          rewritten to match.  Named paramstyles get per-row suffixed keys
          (``<name>__<row index>``); positional paramstyles get a
          flattened tuple.

        :param statement: the compiled INSERT string containing a single
         VALUES tuple.
        :param parameters: DBAPI-level parameter sets, one per row.
        :param compiled_parameters: compiled parameter sets parallel to
         ``parameters``; sentinel values are extracted from these.
        :param generic_setinputsizes: optional setinputsizes entries,
         re-keyed per row for each multi-row batch.
        :param batch_size: maximum rows per batch; may be reduced further
         by ``dialect.insertmanyvalues_max_parameters``.
        :param sort_by_parameter_order: passed through to every batch.
        :param schema_translate_map: when present, the VALUES expressions
         are passed through ``preparer._render_schema_translates`` before
         being matched against ``statement``.
        """
        imv = self._insertmanyvalues
        assert imv is not None

        # callable extracting client-side sentinel value(s) from a
        # compiled parameter set, when sentinel param keys were recorded
        if not imv.sentinel_param_keys:
            _sentinel_from_params = None
        else:
            _sentinel_from_params = operator.itemgetter(
                *imv.sentinel_param_keys
            )

        lenparams = len(parameters)
        if imv.is_default_expr and not self.dialect.supports_default_metavalue:
            # backend doesn't support
            # INSERT INTO table (pk_col) VALUES (DEFAULT), (DEFAULT), ...
            # at the moment this is basically SQL Server due to
            # not being able to use DEFAULT for identity column
            # just yield out that many single statements! still
            # faster than a whole connection.execute() call ;)
            #
            # note we still are taking advantage of the fact that we know
            # we are using RETURNING. The generalized approach of fetching
            # cursor.lastrowid etc. still goes through the more heavyweight
            # "ExecutionContext per statement" system as it isn't usable
            # as a generic "RETURNING" approach
            use_row_at_a_time = True
            downgraded = False
        elif not self.dialect.supports_multivalues_insert or (
            sort_by_parameter_order
            and self._result_columns
            and (imv.sentinel_columns is None or imv.includes_upsert_behaviors)
        ):
            # deterministic order was requested and the compiler could
            # not organize sentinel columns for this dialect/statement.
            # use row at a time
            use_row_at_a_time = True
            downgraded = True
        else:
            use_row_at_a_time = False
            downgraded = False

        if use_row_at_a_time:
            # one single-row batch per parameter set; batch numbers are
            # 1-based and the total equals the number of parameter sets
            for batchnum, (param, compiled_param) in enumerate(
                cast(
                    "Sequence[Tuple[_DBAPISingleExecuteParams, _MutableCoreSingleExecuteParams]]",  # noqa: E501
                    zip(parameters, compiled_parameters),
                ),
                1,
            ):
                yield _InsertManyValuesBatch(
                    statement,
                    param,
                    generic_setinputsizes,
                    [param],
                    (
                        [_sentinel_from_params(compiled_param)]
                        if _sentinel_from_params
                        else []
                    ),
                    1,
                    batchnum,
                    lenparams,
                    sort_by_parameter_order,
                    downgraded,
                )
            return

        # rst: optional schema-translate renderer applied to VALUES
        # expressions so they match the text found in ``statement``
        if schema_translate_map:
            rst = functools.partial(
                self.preparer._render_schema_translates,
                schema_translate_map=schema_translate_map,
            )
        else:
            rst = None

        imv_single_values_expr = imv.single_values_expr
        if rst:
            imv_single_values_expr = rst(imv_single_values_expr)

        # swap the single "(...)" VALUES tuple for a placeholder token that
        # each batch below replaces with its expanded VALUES list
        executemany_values = f"({imv_single_values_expr})"
        statement = statement.replace(executemany_values, "__EXECMANY_TOKEN__")

        # Use optional insertmanyvalues_max_parameters
        # to further shrink the batch size so that there are no more than
        # insertmanyvalues_max_parameters params.
        # Currently used by SQL Server, which limits statements to 2100 bound
        # parameters (actually 2099).
        max_params = self.dialect.insertmanyvalues_max_parameters
        if max_params:
            # NOTE(review): if num_params_per_batch is 0 this divides by
            # zero; if num_params_outside_of_batch >= max_params the
            # result is a batch_size <= 0, and a 0 batch_size makes the
            # total_batches computation below raise ZeroDivisionError.
            # Presumably unreachable in practice -- confirm.
            total_num_of_params = len(self.bind_names)
            num_params_per_batch = len(imv.insert_crud_params)
            num_params_outside_of_batch = (
                total_num_of_params - num_params_per_batch
            )
            batch_size = min(
                batch_size,
                (
                    (max_params - num_params_outside_of_batch)
                    // num_params_per_batch
                ),
            )

        # mutable copies; batches are sliced off the front of these below
        batches = cast("List[Sequence[Any]]", list(parameters))
        compiled_batches = cast(
            "List[Sequence[Any]]", list(compiled_parameters)
        )

        processed_setinputsizes: Optional[_GenericSetInputSizesType] = None
        batchnum = 1
        # ceiling division: number of batches of at most batch_size rows
        total_batches = lenparams // batch_size + (
            1 if lenparams % batch_size else 0
        )

        insert_crud_params = imv.insert_crud_params
        assert insert_crud_params is not None

        if rst:
            insert_crud_params = [
                (col, key, rst(expr), st)
                for col, key, expr, st in insert_crud_params
            ]

        escaped_bind_names: Mapping[str, str]
        expand_pos_lower_index = expand_pos_upper_index = 0

        if not self.positional:
            # named paramstyle: each row gets its own copy of the VALUES
            # tuple, in which the bound parameter names used by the crud
            # params are suffixed with a per-row index
            if self.escaped_bind_names:
                escaped_bind_names = self.escaped_bind_names
            else:
                escaped_bind_names = {}

            all_keys = set(parameters[0])

            def apply_placeholders(keys, formatted):
                # rewrite each bind placeholder "<name>" in ``formatted`` to
                # "<name>__EXECMANY_INDEX__"; the index marker is replaced
                # with the row number for each batch row
                for key in keys:
                    key = escaped_bind_names.get(key, key)
                    formatted = formatted.replace(
                        self.bindtemplate % {"name": key},
                        self.bindtemplate
                        % {"name": f"{key}__EXECMANY_INDEX__"},
                    )
                return formatted

            if imv.embed_values_counter:
                imv_values_counter = ", _IMV_VALUES_COUNTER"
            else:
                imv_values_counter = ""
            formatted_values_clause = f"""({', '.join(
                apply_placeholders(bind_keys, formatted)
                for _, _, formatted, bind_keys in insert_crud_params
            )}{imv_values_counter})"""

            # parameter keys that are duplicated per row ...
            keys_to_replace = all_keys.intersection(
                escaped_bind_names.get(key, key)
                for _, _, _, bind_keys in insert_crud_params
                for key in bind_keys
            )
            # ... and the remaining keys, taken once from the first row
            base_parameters = {
                key: parameters[0][key]
                for key in all_keys.difference(keys_to_replace)
            }
            executemany_values_w_comma = ""
        else:
            # positional paramstyle: the VALUES tuple is repeated verbatim
            # and the parameter tuples are concatenated
            formatted_values_clause = ""
            keys_to_replace = set()
            base_parameters = {}

            if imv.embed_values_counter:
                executemany_values_w_comma = (
                    f"({imv_single_values_expr}, _IMV_VALUES_COUNTER), "
                )
            else:
                executemany_values_w_comma = f"({imv_single_values_expr}), "

            all_names_we_will_expand: Set[str] = set()
            for elem in imv.insert_crud_params:
                all_names_we_will_expand.update(elem[3])

            # get the start and end position in a particular list
            # of parameters where we will be doing the "expanding".
            # statements can have params on either side or both sides,
            # given RETURNING and CTEs
            if all_names_we_will_expand:
                positiontup = self.positiontup
                assert positiontup is not None

                all_expand_positions = {
                    idx
                    for idx, name in enumerate(positiontup)
                    if name in all_names_we_will_expand
                }
                expand_pos_lower_index = min(all_expand_positions)
                expand_pos_upper_index = max(all_expand_positions) + 1
                # the expanded positions must form one contiguous run
                assert (
                    len(all_expand_positions)
                    == expand_pos_upper_index - expand_pos_lower_index
                )

            if self._numeric_binds:
                # turn numbered placeholders into "%s" so that they can be
                # renumbered per batch with %-formatting below
                escaped = re.escape(self._numeric_binds_identifier_char)
                executemany_values_w_comma = re.sub(
                    rf"{escaped}\d+", "%s", executemany_values_w_comma
                )

        while batches:
            # consume up to batch_size rows from the front of both lists
            batch = batches[0:batch_size]
            compiled_batch = compiled_batches[0:batch_size]

            batches[0:batch_size] = []
            compiled_batches[0:batch_size] = []

            # a full batch unless this is the final, possibly shorter, one
            if batches:
                current_batch_size = batch_size
            else:
                current_batch_size = len(batch)

            if generic_setinputsizes:
                # if setinputsizes is present, expand this collection to
                # suit the batch length as well
                # currently this will be mssql+pyodbc for internal dialects
                processed_setinputsizes = [
                    (new_key, len_, typ)
                    for new_key, len_, typ in (
                        (f"{key}_{index}", len_, typ)
                        for index in range(current_batch_size)
                        for key, len_, typ in generic_setinputsizes
                    )
                ]

            replaced_parameters: Any
            if self.positional:
                num_ins_params = imv.num_positional_params_counted

                batch_iterator: Iterable[Sequence[Any]]
                extra_params_left: Sequence[Any]
                extra_params_right: Sequence[Any]

                if num_ins_params == len(batch[0]):
                    # every positional parameter belongs to VALUES
                    extra_params_left = extra_params_right = ()
                    batch_iterator = batch
                else:
                    # parameters outside the VALUES span are taken once,
                    # from the first row, and kept on their original side
                    extra_params_left = batch[0][:expand_pos_lower_index]
                    extra_params_right = batch[0][expand_pos_upper_index:]
                    batch_iterator = (
                        b[expand_pos_lower_index:expand_pos_upper_index]
                        for b in batch
                    )

                # repeat the "(...), " tuple per row; [:-2] strips the
                # trailing ", "
                if imv.embed_values_counter:
                    expanded_values_string = (
                        "".join(
                            executemany_values_w_comma.replace(
                                "_IMV_VALUES_COUNTER", str(i)
                            )
                            for i, _ in enumerate(batch)
                        )
                    )[:-2]
                else:
                    expanded_values_string = (
                        (executemany_values_w_comma * current_batch_size)
                    )[:-2]

                if self._numeric_binds and num_ins_params > 0:
                    # numeric will always number the parameters inside of
                    # VALUES (and thus order self.positiontup) to be higher
                    # than non-VALUES parameters, no matter where in the
                    # statement those non-VALUES parameters appear (this is
                    # ensured in _process_numeric by numbering first all
                    # params that are not in _values_bindparam)
                    # therefore all extra params are always
                    # on the left side and numbered lower than the VALUES
                    # parameters
                    assert not extra_params_right

                    start = expand_pos_lower_index + 1
                    end = num_ins_params * (current_batch_size) + start

                    # need to format here, since statement may contain
                    # unescaped %, while values_string contains just (%s, %s)
                    positions = tuple(
                        f"{self._numeric_binds_identifier_char}{i}"
                        for i in range(start, end)
                    )
                    expanded_values_string = expanded_values_string % positions

                replaced_statement = statement.replace(
                    "__EXECMANY_TOKEN__", expanded_values_string
                )

                replaced_parameters = tuple(
                    itertools.chain.from_iterable(batch_iterator)
                )

                replaced_parameters = (
                    extra_params_left
                    + replaced_parameters
                    + extra_params_right
                )

            else:
                replaced_values_clauses = []
                replaced_parameters = base_parameters.copy()

                for i, param in enumerate(batch):
                    # "<name>__EXECMANY_INDEX__" becomes "<name>__<i>",
                    # matching the keys added to replaced_parameters below
                    fmv = formatted_values_clause.replace(
                        "EXECMANY_INDEX__", str(i)
                    )
                    if imv.embed_values_counter:
                        fmv = fmv.replace("_IMV_VALUES_COUNTER", str(i))

                    replaced_values_clauses.append(fmv)
                    replaced_parameters.update(
                        {f"{key}__{i}": param[key] for key in keys_to_replace}
                    )

                replaced_statement = statement.replace(
                    "__EXECMANY_TOKEN__",
                    ", ".join(replaced_values_clauses),
                )

            yield _InsertManyValuesBatch(
                replaced_statement,
                replaced_parameters,
                processed_setinputsizes,
                batch,
                (
                    [_sentinel_from_params(cb) for cb in compiled_batch]
                    if _sentinel_from_params
                    else []
                ),
                current_batch_size,
                batchnum,
                total_batches,
                sort_by_parameter_order,
                False,
            )
            batchnum += 1
5848
    def visit_insert(
        self, insert_stmt, visited_bindparam=None, visiting_cte=None, **kw
    ):
        """Compile an INSERT statement to a string.

        Renders ``INSERT [prefixes] INTO <table> [(<columns>)]`` followed by
        exactly one of:

        * the compiled INSERT..SELECT;
        * ``DEFAULT VALUES``;
        * a multi-row ``VALUES`` list;
        * a single ``VALUES`` tuple, or ``SELECT ... FROM (VALUES ...)
          ORDER BY sen_counter`` for implicit sentinels on dialects
          flagged with USE_INSERT_FROM_SELECT.

        After that come any post-values clause, RETURNING (unless the
        dialect places it before VALUES), and CTEs.

        When crud.py reports the statement as eligible for
        "insertmanyvalues", an :class:`._InsertManyValues` describing the
        VALUES tuple is stored on ``self._insertmanyvalues``.

        :param insert_stmt: the INSERT construct; rebound to its compile
         state's statement.
        :param visited_bindparam: any incoming value is discarded (reset to
         ``None`` on entry).
        :param visiting_cte: the enclosing CTE when compiled inside one; in
         that case the INSERT is never treated as top level and its
         positional parameters are not counted.
        :raises CompileError: for an empty INSERT the dialect cannot
         render, for multi-row VALUES the dialect does not support, or for
         multi-row VALUES combined with RETURNING and
         ``sort_by_parameter_order``.
        """
        compile_state = insert_stmt._compile_state_factory(
            insert_stmt, self, **kw
        )
        insert_stmt = compile_state.statement

        if visiting_cte is not None:
            kw["visiting_cte"] = visiting_cte
            toplevel = False
        else:
            toplevel = not self.stack

        if toplevel:
            self.isinsert = True
            if not self.dml_compile_state:
                self.dml_compile_state = compile_state
            if not self.compile_state:
                self.compile_state = compile_state

        # pushed for the duration of this INSERT; popped before returning
        self.stack.append(
            {
                "correlate_froms": set(),
                "asfrom_froms": set(),
                "selectable": insert_stmt,
            }
        )

        counted_bindparam = 0

        # reset any incoming "visited_bindparam" collection
        visited_bindparam = None

        # for positional, insertmanyvalues needs to know how many
        # bound parameters are in the VALUES sequence; there's no simple
        # rule because default expressions etc. can have zero or more
        # params inside them. After multiple attempts to figure this out,
        # this very simplistic "count after" works and is
        # likely the least amount of callcounts, though looks clumsy
        if self.positional and visiting_cte is None:
            # if we are inside a CTE, don't count parameters
            # here since they wont be for insertmanyvalues. keep
            # visited_bindparam at None so no counting happens.
            # see #9173
            visited_bindparam = []

        crud_params_struct = crud._get_crud_params(
            self,
            insert_stmt,
            compile_state,
            toplevel,
            visited_bindparam=visited_bindparam,
            **kw,
        )

        if self.positional and visited_bindparam is not None:
            counted_bindparam = len(visited_bindparam)
            if self._numeric_binds:
                # remember which binds live inside VALUES for numeric
                # paramstyle numbering
                if self._values_bindparam is not None:
                    self._values_bindparam += visited_bindparam
                else:
                    self._values_bindparam = visited_bindparam

        crud_params_single = crud_params_struct.single_params

        if (
            not crud_params_single
            and not self.dialect.supports_default_values
            and not self.dialect.supports_default_metavalue
            and not self.dialect.supports_empty_insert
        ):
            raise exc.CompileError(
                "The '%s' dialect with current database "
                "version settings does not support empty "
                "inserts." % self.dialect.name
            )

        if compile_state._has_multi_parameters:
            if not self.dialect.supports_multivalues_insert:
                raise exc.CompileError(
                    "The '%s' dialect with current database "
                    "version settings does not support "
                    "in-place multirow inserts." % self.dialect.name
                )
            elif (
                self.implicit_returning or insert_stmt._returning
            ) and insert_stmt._sort_by_parameter_order:
                # NOTE(review): "determinstically" is misspelled in this
                # message; left unchanged in case callers/tests match it.
                raise exc.CompileError(
                    "RETURNING cannot be determinstically sorted when "
                    "using an INSERT which includes multi-row values()."
                )
            # NOTE(review): crud_params_single already holds this value
            # (assigned above); this and the else-branch reassignment are
            # redundant.
            crud_params_single = crud_params_struct.single_params
        else:
            crud_params_single = crud_params_struct.single_params

        preparer = self.preparer
        supports_default_values = self.dialect.supports_default_values

        text = "INSERT "

        if insert_stmt._prefixes:
            text += self._generate_prefixes(
                insert_stmt, insert_stmt._prefixes, **kw
            )

        text += "INTO "
        table_text = preparer.format_table(insert_stmt.table)

        if insert_stmt._hints:
            _, table_text = self._setup_crud_hints(insert_stmt, table_text)

        if insert_stmt._independent_ctes:
            self._dispatch_independent_ctes(insert_stmt, kw)

        text += table_text

        # the column list is omitted only when there are no columns and
        # the dialect supports DEFAULT VALUES
        if crud_params_single or not supports_default_values:
            text += " (%s)" % ", ".join(
                [expr for _, expr, _, _ in crud_params_single]
            )

        # look for insertmanyvalues attributes that would have been configured
        # by crud.py as it scanned through the columns to be part of the
        # INSERT
        use_insertmanyvalues = crud_params_struct.use_insertmanyvalues
        named_sentinel_params: Optional[Sequence[str]] = None
        add_sentinel_cols = None
        implicit_sentinel = False

        returning_cols = self.implicit_returning or insert_stmt._returning
        if returning_cols:
            add_sentinel_cols = crud_params_struct.use_sentinel_columns
            if add_sentinel_cols is not None:
                assert use_insertmanyvalues

                # search for the sentinel column explicitly present
                # in the INSERT columns list, and additionally check that
                # this column has a bound parameter name set up that's in the
                # parameter list. If both of these cases are present, it means
                # we will have a client side value for the sentinel in each
                # parameter set.

                _params_by_col = {
                    col: param_names
                    for col, _, _, param_names in crud_params_single
                }
                named_sentinel_params = []
                for _add_sentinel_col in add_sentinel_cols:
                    if _add_sentinel_col not in _params_by_col:
                        named_sentinel_params = None
                        break
                    param_name = self._within_exec_param_key_getter(
                        _add_sentinel_col
                    )
                    if param_name not in _params_by_col[_add_sentinel_col]:
                        named_sentinel_params = None
                        break
                    named_sentinel_params.append(param_name)

                if named_sentinel_params is None:
                    # if we are not going to have a client side value for
                    # the sentinel in the parameter set, that means it's
                    # an autoincrement, an IDENTITY, or a server-side SQL
                    # expression like nextval('seqname'). So this is
                    # an "implicit" sentinel; we will look for it in
                    # RETURNING
                    # only, and then sort on it. For this case on PG,
                    # SQL Server we have to use a special INSERT form
                    # that guarantees the server side function lines up with
                    # the entries in the VALUES.
                    if (
                        self.dialect.insertmanyvalues_implicit_sentinel
                        & InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT
                    ):
                        implicit_sentinel = True
                    else:
                        # here, we are not using a sentinel at all
                        # and we are likely the SQLite dialect.
                        # The first add_sentinel_col that we have should not
                        # be marked as "insert_sentinel=True". if it was,
                        # an error should have been raised in
                        # _get_sentinel_column_for_table.
                        assert not add_sentinel_cols[0]._insert_sentinel, (
                            "sentinel selection rules should have prevented "
                            "us from getting here for this dialect"
                        )

                # always put the sentinel columns last. even if they are
                # in the returning list already, they will be there twice
                # then.
                returning_cols = list(returning_cols) + list(add_sentinel_cols)

            returning_clause = self.returning_clause(
                insert_stmt,
                returning_cols,
                populate_result_map=toplevel,
            )

            if self.returning_precedes_values:
                text += " " + returning_clause

        else:
            returning_clause = None

        if insert_stmt.select is not None:
            # placed here by crud.py
            select_text = self.process(
                self.stack[-1]["insert_from_select"], insert_into=True, **kw
            )

            if self.ctes and self.dialect.cte_follows_insert:
                nesting_level = len(self.stack) if not toplevel else None
                text += " %s%s" % (
                    self._render_cte_clause(
                        nesting_level=nesting_level,
                        include_following_stack=True,
                    ),
                    select_text,
                )
            else:
                text += " %s" % select_text
        elif not crud_params_single and supports_default_values:
            text += " DEFAULT VALUES"
            if use_insertmanyvalues:
                # is_default_expr=True; the dialect's default metavalue
                # token is recorded as the single VALUES expression
                self._insertmanyvalues = _InsertManyValues(
                    True,
                    self.dialect.default_metavalue_token,
                    cast(
                        "List[crud._CrudParamElementStr]", crud_params_single
                    ),
                    counted_bindparam,
                    sort_by_parameter_order=(
                        insert_stmt._sort_by_parameter_order
                    ),
                    includes_upsert_behaviors=(
                        insert_stmt._post_values_clause is not None
                    ),
                    sentinel_columns=add_sentinel_cols,
                    num_sentinel_columns=(
                        len(add_sentinel_cols) if add_sentinel_cols else 0
                    ),
                    implicit_sentinel=implicit_sentinel,
                )
        elif compile_state._has_multi_parameters:
            # explicit multi-row values(): one "(...)" group per row
            text += " VALUES %s" % (
                ", ".join(
                    "(%s)"
                    % (", ".join(value for _, _, value, _ in crud_param_set))
                    for crud_param_set in crud_params_struct.all_multi_params
                ),
            )
        else:
            insert_single_values_expr = ", ".join(
                [
                    value
                    for _, _, value, _ in cast(
                        "List[crud._CrudParamElementStr]",
                        crud_params_single,
                    )
                ]
            )

            if use_insertmanyvalues:
                if (
                    implicit_sentinel
                    and (
                        self.dialect.insertmanyvalues_implicit_sentinel
                        & InsertmanyvaluesSentinelOpts.USE_INSERT_FROM_SELECT
                    )
                    # this is checking if we have
                    # INSERT INTO table (id) VALUES (DEFAULT).
                    and not (crud_params_struct.is_default_metavalue_only)
                ):
                    # if we have a sentinel column that is server generated,
                    # then for selected backends render the VALUES list as a
                    # subquery. This is the orderable form supported by
                    # PostgreSQL and SQL Server.
                    embed_sentinel_value = True

                    render_bind_casts = (
                        self.dialect.insertmanyvalues_implicit_sentinel
                        & InsertmanyvaluesSentinelOpts.RENDER_SELECT_COL_CASTS
                    )

                    # subquery column names p0, p1, ... one per crud param
                    colnames = ", ".join(
                        f"p{i}" for i, _ in enumerate(crud_params_single)
                    )

                    if render_bind_casts:
                        # render casts for the SELECT list. For PG, we are
                        # already rendering bind casts in the parameter list,
                        # selectively for the more "tricky" types like ARRAY.
                        # however, even for the "easy" types, if the parameter
                        # is NULL for every entry, PG gives up and says
                        # "it must be TEXT", which fails for other easy types
                        # like ints. So we cast on this side too.
                        colnames_w_cast = ", ".join(
                            self.render_bind_cast(
                                col.type,
                                col.type._unwrapped_dialect_impl(self.dialect),
                                f"p{i}",
                            )
                            for i, (col, *_) in enumerate(crud_params_single)
                        )
                    else:
                        colnames_w_cast = colnames

                    # sen_counter is the extra per-row column the VALUES
                    # tuple carries once _deliver_insertmanyvalues_batches
                    # substitutes _IMV_VALUES_COUNTER (embed_values_counter)
                    text += (
                        f" SELECT {colnames_w_cast} FROM "
                        f"(VALUES ({insert_single_values_expr})) "
                        f"AS imp_sen({colnames}, sen_counter) "
                        "ORDER BY sen_counter"
                    )
                else:
                    # otherwise, if no sentinel or backend doesn't support
                    # orderable subquery form, use a plain VALUES list
                    embed_sentinel_value = False
                    text += f" VALUES ({insert_single_values_expr})"

                self._insertmanyvalues = _InsertManyValues(
                    is_default_expr=False,
                    single_values_expr=insert_single_values_expr,
                    insert_crud_params=cast(
                        "List[crud._CrudParamElementStr]",
                        crud_params_single,
                    ),
                    num_positional_params_counted=counted_bindparam,
                    sort_by_parameter_order=(
                        insert_stmt._sort_by_parameter_order
                    ),
                    includes_upsert_behaviors=(
                        insert_stmt._post_values_clause is not None
                    ),
                    sentinel_columns=add_sentinel_cols,
                    num_sentinel_columns=(
                        len(add_sentinel_cols) if add_sentinel_cols else 0
                    ),
                    sentinel_param_keys=named_sentinel_params,
                    implicit_sentinel=implicit_sentinel,
                    embed_values_counter=embed_sentinel_value,
                )

            else:
                text += f" VALUES ({insert_single_values_expr})"

        if insert_stmt._post_values_clause is not None:
            post_values_clause = self.process(
                insert_stmt._post_values_clause, **kw
            )
            if post_values_clause:
                text += " " + post_values_clause

        if returning_clause and not self.returning_precedes_values:
            text += " " + returning_clause

        # CTEs not already rendered ahead of the INSERT..SELECT above are
        # prefixed to the whole statement
        if self.ctes and not self.dialect.cte_follows_insert:
            nesting_level = len(self.stack) if not toplevel else None
            text = (
                self._render_cte_clause(
                    nesting_level=nesting_level,
                    include_following_stack=True,
                )
                + text
            )

        self.stack.pop(-1)

        return text
6218
6219 def update_limit_clause(self, update_stmt):
6220 """Provide a hook for MySQL to add LIMIT to the UPDATE"""
6221 return None
6222
6223 def delete_limit_clause(self, delete_stmt):
6224 """Provide a hook for MySQL to add LIMIT to the DELETE"""
6225 return None
6226
6227 def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
6228 """Provide a hook to override the initial table clause
6229 in an UPDATE statement.
6230
6231 MySQL overrides this.
6232
6233 """
6234 kw["asfrom"] = True
6235 return from_table._compiler_dispatch(self, iscrud=True, **kw)
6236
6237 def update_from_clause(
6238 self, update_stmt, from_table, extra_froms, from_hints, **kw
6239 ):
6240 """Provide a hook to override the generation of an
6241 UPDATE..FROM clause.
6242
6243 MySQL and MSSQL override this.
6244
6245 """
6246 raise NotImplementedError(
6247 "This backend does not support multiple-table "
6248 "criteria within UPDATE"
6249 )
6250
def visit_update(
    self,
    update_stmt: Update,
    visiting_cte: Optional[CTE] = None,
    **kw: Any,
) -> str:
    """Compile an UPDATE statement into its SQL string.

    Pieces are emitted in this order: ``UPDATE``, any prefixes, the
    table clause (:meth:`.update_tables_clause`), ``SET`` assignments,
    RETURNING (only when ``returning_precedes_values``), the extra FROM
    clause (:meth:`.update_from_clause`, only for multi-table criteria),
    WHERE, :meth:`.update_limit_clause`, and finally RETURNING (when it
    does not precede the values).  If the compiler holds CTEs
    (``self.ctes``), their clause is placed in front of the statement.

    :param update_stmt: the UPDATE construct to compile.  Compilation
     continues with ``compile_state.statement``.
    :param visiting_cte: when given, the UPDATE is being compiled on
     behalf of a CTE.  It is then never treated as top-level, and the
     value is forwarded to sub-compilations via ``kw``.
    :param kw: compilation keyword arguments forwarded to the
     sub-element compilations.
    :return: the rendered UPDATE string.
    """
    compile_state = update_stmt._compile_state_factory(
        update_stmt, self, **kw
    )
    if TYPE_CHECKING:
        assert isinstance(compile_state, UpdateDMLState)
    # continue with the statement object held by the compile state
    update_stmt = compile_state.statement  # type: ignore[assignment]

    if visiting_cte is not None:
        kw["visiting_cte"] = visiting_cte
        toplevel = False
    else:
        # outermost statement only when nothing is on the stack yet
        toplevel = not self.stack

    if toplevel:
        # the outermost UPDATE records itself as the compiler's
        # primary DML / compile state, unless one was already set
        self.isupdate = True
        if not self.dml_compile_state:
            self.dml_compile_state = compile_state
        if not self.compile_state:
            self.compile_state = compile_state

    if self.linting & COLLECT_CARTESIAN_PRODUCTS:
        # a fresh linter gathers FROM elements for the cartesian
        # product check; it only warns if WARN_LINTING is also set
        from_linter = FromLinter({}, set())
        warn_linting = self.linting & WARN_LINTING
        if toplevel:
            self.from_linter = from_linter
    else:
        from_linter = None
        warn_linting = False

    extra_froms = compile_state._extra_froms
    is_multitable = bool(extra_froms)

    if is_multitable:
        # main table might be a JOIN
        main_froms = set(_from_objects(update_stmt.table))
        # don't render again any extra FROM already part of the main
        # table expression
        render_extra_froms = [
            f for f in extra_froms if f not in main_froms
        ]
        correlate_froms = main_froms.union(extra_froms)
    else:
        render_extra_froms = []
        correlate_froms = {update_stmt.table}

    # push a stack entry describing this UPDATE; popped at the end
    self.stack.append(
        {
            "correlate_froms": correlate_froms,
            "asfrom_froms": correlate_froms,
            "selectable": update_stmt,
        }
    )

    text = "UPDATE "

    if update_stmt._prefixes:
        text += self._generate_prefixes(
            update_stmt, update_stmt._prefixes, **kw
        )

    table_text = self.update_tables_clause(
        update_stmt,
        update_stmt.table,
        render_extra_froms,
        from_linter=from_linter,
        **kw,
    )
    crud_params_struct = crud._get_crud_params(
        self, update_stmt, compile_state, toplevel, **kw
    )
    crud_params = crud_params_struct.single_params

    if update_stmt._hints:
        # may rewrite table_text to carry the dialect's hints
        dialect_hints, table_text = self._setup_crud_hints(
            update_stmt, table_text
        )
    else:
        dialect_hints = None

    if update_stmt._independent_ctes:
        self._dispatch_independent_ctes(update_stmt, kw)

    text += table_text

    text += " SET "
    # each crud param is a 4-tuple; the 2nd and 3rd items are the
    # rendered column and value text
    text += ", ".join(
        expr + "=" + value
        for _, expr, value, _ in cast(
            "List[Tuple[Any, str, str, Any]]", crud_params
        )
    )

    if self.implicit_returning or update_stmt._returning:
        if self.returning_precedes_values:
            # dialects that place RETURNING before FROM / WHERE
            text += " " + self.returning_clause(
                update_stmt,
                self.implicit_returning or update_stmt._returning,
                populate_result_map=toplevel,
            )

    if extra_froms:
        extra_from_text = self.update_from_clause(
            update_stmt,
            update_stmt.table,
            render_extra_froms,
            dialect_hints,
            from_linter=from_linter,
            **kw,
        )
        if extra_from_text:
            text += " " + extra_from_text

    if update_stmt._where_criteria:
        t = self._generate_delimited_and_list(
            update_stmt._where_criteria, from_linter=from_linter, **kw
        )
        if t:
            text += " WHERE " + t

    limit_clause = self.update_limit_clause(update_stmt)
    if limit_clause:
        text += " " + limit_clause

    if (
        self.implicit_returning or update_stmt._returning
    ) and not self.returning_precedes_values:
        # dialects that place RETURNING at the end of the statement
        text += " " + self.returning_clause(
            update_stmt,
            self.implicit_returning or update_stmt._returning,
            populate_result_map=toplevel,
        )

    if self.ctes:
        # nested statements pass their stack depth as the nesting level
        nesting_level = len(self.stack) if not toplevel else None
        text = self._render_cte_clause(nesting_level=nesting_level) + text

    if warn_linting:
        assert from_linter is not None
        from_linter.warn(stmt_type="UPDATE")

    self.stack.pop(-1)

    return text  # type: ignore[no-any-return]
6398
def delete_extra_from_clause(
    self, delete_stmt, from_table, extra_froms, from_hints, **kw
):
    """Render the extra-table portion of a multi-table DELETE.

    Dialects can use this hook to implement e.g. ``DELETE..USING``.
    There is no generic rendering, so the base compiler refuses it.
    MySQL and MSSQL override this hook.

    :raises NotImplementedError: always, in this base implementation.
    """
    unsupported_msg = (
        "This backend does not support multiple-table "
        "criteria within DELETE"
    )
    raise NotImplementedError(unsupported_msg)
6414
def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
    """Render the table text that follows ``DELETE FROM``.

    *from_table* is dispatched as a FROM element (``asfrom=True``) in a
    CRUD context (``iscrud=True``).  *delete_stmt* and *extra_froms*
    are not used by this base implementation.

    :param kw: forwarded to the table's compiler dispatch.  Passing
     ``asfrom`` or ``iscrud`` here collides with the explicit keywords
     and raises ``TypeError``.
    :return: the rendered table text.
    """
    dispatch = from_table._compiler_dispatch
    return dispatch(self, asfrom=True, iscrud=True, **kw)
6419
def visit_delete(self, delete_stmt, visiting_cte=None, **kw):
    """Compile a DELETE statement into its SQL string.

    Pieces are emitted in this order: ``DELETE``, any prefixes,
    ``FROM`` plus the table clause (:meth:`.delete_table_clause`),
    RETURNING (only when ``returning_precedes_values``), the extra-table
    clause (:meth:`.delete_extra_from_clause`, only for multi-table
    criteria), WHERE, :meth:`.delete_limit_clause`, and finally
    RETURNING (when it does not precede the values).  If the compiler
    holds CTEs (``self.ctes``), their clause is placed in front of the
    statement.

    :param delete_stmt: the DELETE construct to compile.  Compilation
     continues with ``compile_state.statement``.
    :param visiting_cte: when given, the DELETE is being compiled on
     behalf of a CTE.  It is then never treated as top-level, and the
     value is forwarded to sub-compilations via ``kw``.
    :param kw: compilation keyword arguments forwarded to most
     sub-element compilations (not to :meth:`.delete_table_clause`).
    :return: the rendered DELETE string.
    """
    compile_state = delete_stmt._compile_state_factory(
        delete_stmt, self, **kw
    )
    # continue with the statement object held by the compile state
    delete_stmt = compile_state.statement

    if visiting_cte is not None:
        kw["visiting_cte"] = visiting_cte
        toplevel = False
    else:
        # outermost statement only when nothing is on the stack yet
        toplevel = not self.stack

    if toplevel:
        # the outermost DELETE records itself as the compiler's
        # primary DML / compile state, unless one was already set
        self.isdelete = True
        if not self.dml_compile_state:
            self.dml_compile_state = compile_state
        if not self.compile_state:
            self.compile_state = compile_state

    if self.linting & COLLECT_CARTESIAN_PRODUCTS:
        # a fresh linter gathers FROM elements for the cartesian
        # product check; it only warns if WARN_LINTING is also set
        from_linter = FromLinter({}, set())
        warn_linting = self.linting & WARN_LINTING
        if toplevel:
            self.from_linter = from_linter
    else:
        from_linter = None
        warn_linting = False

    extra_froms = compile_state._extra_froms

    # push a stack entry describing this DELETE; popped at the end
    correlate_froms = {delete_stmt.table}.union(extra_froms)
    self.stack.append(
        {
            "correlate_froms": correlate_froms,
            "asfrom_froms": correlate_froms,
            "selectable": delete_stmt,
        }
    )

    text = "DELETE "

    if delete_stmt._prefixes:
        text += self._generate_prefixes(
            delete_stmt, delete_stmt._prefixes, **kw
        )

    text += "FROM "

    try:
        table_text = self.delete_table_clause(
            delete_stmt,
            delete_stmt.table,
            extra_froms,
            from_linter=from_linter,
        )
    except TypeError:
        # anticipate 3rd party dialects that don't include **kw
        # TODO: remove in 2.1
        # NOTE: any TypeError raised inside a conforming override also
        # lands here and triggers the legacy-signature retry
        table_text = self.delete_table_clause(
            delete_stmt, delete_stmt.table, extra_froms
        )
        if from_linter:
            # the legacy call did not receive the linter, so register
            # the table with it separately; the output is discarded
            _ = self.process(delete_stmt.table, from_linter=from_linter)

    # called for its side effects only; the result is not used here
    crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw)

    if delete_stmt._hints:
        # may rewrite table_text to carry the dialect's hints
        dialect_hints, table_text = self._setup_crud_hints(
            delete_stmt, table_text
        )
    else:
        dialect_hints = None

    if delete_stmt._independent_ctes:
        self._dispatch_independent_ctes(delete_stmt, kw)

    text += table_text

    if (
        self.implicit_returning or delete_stmt._returning
    ) and self.returning_precedes_values:
        # dialects that place RETURNING before USING / WHERE
        text += " " + self.returning_clause(
            delete_stmt,
            self.implicit_returning or delete_stmt._returning,
            populate_result_map=toplevel,
        )

    if extra_froms:
        extra_from_text = self.delete_extra_from_clause(
            delete_stmt,
            delete_stmt.table,
            extra_froms,
            dialect_hints,
            from_linter=from_linter,
            **kw,
        )
        if extra_from_text:
            text += " " + extra_from_text

    if delete_stmt._where_criteria:
        t = self._generate_delimited_and_list(
            delete_stmt._where_criteria, from_linter=from_linter, **kw
        )
        if t:
            text += " WHERE " + t

    limit_clause = self.delete_limit_clause(delete_stmt)
    if limit_clause:
        text += " " + limit_clause

    if (
        self.implicit_returning or delete_stmt._returning
    ) and not self.returning_precedes_values:
        # dialects that place RETURNING at the end of the statement
        text += " " + self.returning_clause(
            delete_stmt,
            self.implicit_returning or delete_stmt._returning,
            populate_result_map=toplevel,
        )

    if self.ctes:
        # nested statements pass their stack depth as the nesting level
        nesting_level = len(self.stack) if not toplevel else None
        text = self._render_cte_clause(nesting_level=nesting_level) + text

    if warn_linting:
        assert from_linter is not None
        from_linter.warn(stmt_type="DELETE")

    self.stack.pop(-1)

    return text
6550
def visit_savepoint(self, savepoint_stmt, **kw):
    """Render ``SAVEPOINT <name>`` using the preparer's savepoint name."""
    savepoint_name = self.preparer.format_savepoint(savepoint_stmt)
    return "SAVEPOINT %s" % savepoint_name
6553
def visit_rollback_to_savepoint(self, savepoint_stmt, **kw):
    """Render ``ROLLBACK TO SAVEPOINT <name>``."""
    savepoint_name = self.preparer.format_savepoint(savepoint_stmt)
    return "ROLLBACK TO SAVEPOINT %s" % savepoint_name
6558
def visit_release_savepoint(self, savepoint_stmt, **kw):
    """Render ``RELEASE SAVEPOINT <name>``."""
    savepoint_name = self.preparer.format_savepoint(savepoint_stmt)
    return "RELEASE SAVEPOINT %s" % savepoint_name
6563
6564
class StrSQLCompiler(SQLCompiler):
    """A :class:`.SQLCompiler` used when a Core expression element is
    stringified directly, without calling
    :meth:`_expression.ClauseElement.compile`.

    Besides standard SQL, it renders a small set of non-standard
    constructs (index access, regular expression operators, ``TRY_CAST``,
    RETURNING, multi-table UPDATE/DELETE) in a readable, approximate
    form so that basic stringification does not fail.  For more
    substantial custom or dialect-specific SQL constructs, use
    :meth:`_expression.ClauseElement.compile` directly.

    .. seealso::

        :ref:`faq_sql_expression_string`

    """

    def _fallback_column_name(self, column):
        # placeholder text returned in place of a column name
        return "<name unknown>"

    @util.preload_module("sqlalchemy.engine.url")
    def visit_unsupported_compilation(self, element, err, **kw):
        """Try the element's preferred stringify dialect before failing.

        If *element* names a ``stringify_dialect`` other than
        ``"default"``, that dialect's statement compiler is created.
        Unless that compiler is itself a :class:`.StrSQLCompiler`, it
        renders the element.  In every other case the base class
        behavior applies.
        """
        target_name = element.stringify_dialect
        if target_name == "default":
            return super().visit_unsupported_compilation(element, err)

        url_module = util.preloaded.engine_url
        dialect_cls = url_module.URL.create(target_name).get_dialect()
        dialect = dialect_cls()
        delegate = dialect.statement_compiler(
            dialect, None, _supporting_against=self
        )
        if isinstance(delegate, StrSQLCompiler):
            return super().visit_unsupported_compilation(element, err)
        return delegate.process(element, **kw)

    def visit_getitem_binary(self, binary, operator, **kw):
        """Render index access generically as ``left[right]``."""
        left_sql = self.process(binary.left, **kw)
        right_sql = self.process(binary.right, **kw)
        return "%s[%s]" % (left_sql, right_sql)

    def visit_json_getitem_op_binary(self, binary, operator, **kw):
        """Render JSON element access with the generic ``[]`` form."""
        return self.visit_getitem_binary(binary, operator, **kw)

    def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
        """Render JSON path access with the generic ``[]`` form."""
        return self.visit_getitem_binary(binary, operator, **kw)

    def visit_sequence(self, sequence, **kw):
        """Render a sequence as a readable next-value placeholder."""
        seq_name = self.preparer.format_sequence(sequence)
        return f"<next sequence value: {seq_name}>"

    def returning_clause(
        self,
        stmt: UpdateBase,
        returning_cols: Sequence[_ColumnsClauseElement],
        *,
        populate_result_map: bool,
        **kw: Any,
    ) -> str:
        """Render a generic ``RETURNING col, ...`` clause.

        Each selectable column is rendered via ``_label_select_column``.
        *populate_result_map* and *kw* are not used here.
        """
        labeled = []
        for col in base._select_iterables(returning_cols):
            labeled.append(
                self._label_select_column(None, col, True, False, {})
            )
        return "RETURNING " + ", ".join(labeled)

    def update_from_clause(
        self, update_stmt, from_table, extra_froms, from_hints, **kw
    ):
        """Render extra UPDATE tables as a generic ``FROM a, b`` clause."""
        dispatch_kw = dict(kw, asfrom=True)
        rendered = [
            t._compiler_dispatch(self, fromhints=from_hints, **dispatch_kw)
            for t in extra_froms
        ]
        return "FROM " + ", ".join(rendered)

    def delete_extra_from_clause(
        self, delete_stmt, from_table, extra_froms, from_hints, **kw
    ):
        """Render extra DELETE tables as a comma-prefixed table list."""
        dispatch_kw = dict(kw, asfrom=True)
        rendered = [
            t._compiler_dispatch(self, fromhints=from_hints, **dispatch_kw)
            for t in extra_froms
        ]
        return ", " + ", ".join(rendered)

    def visit_empty_set_expr(self, element_types, **kw):
        """Render an empty set as a SELECT that yields no rows."""
        return "SELECT 1 WHERE 1!=1"

    def get_from_hint_text(self, table, text):
        """Render a FROM hint wrapped in square brackets."""
        return "[%s]" % text

    def visit_regexp_match_op_binary(self, binary, operator, **kw):
        """Render a regex match with a generic ``<regexp>`` operator."""
        return self._generate_generic_binary(binary, " <regexp> ", **kw)

    def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
        """Render a negated regex match with ``<not regexp>``."""
        return self._generate_generic_binary(binary, " <not regexp> ", **kw)

    def visit_regexp_replace_op_binary(self, binary, operator, **kw):
        """Render regex replacement as ``<regexp replace>(a, b)``."""
        operands = (
            binary.left._compiler_dispatch(self, **kw),
            binary.right._compiler_dispatch(self, **kw),
        )
        return "<regexp replace>(%s, %s)" % operands

    def visit_try_cast(self, cast, **kwargs):
        """Render a try-cast as ``TRY_CAST(expr AS type)``."""
        expr_sql = cast.clause._compiler_dispatch(self, **kwargs)
        type_sql = cast.typeclause._compiler_dispatch(self, **kwargs)
        return "TRY_CAST(%s AS %s)" % (expr_sql, type_sql)
6674
6675
class DDLCompiler(Compiled):
    """Render DDL constructs (CREATE / DROP / ALTER / COMMENT statements
    and the column, constraint, sequence and identity clauses inside
    them) into SQL strings.

    SQL expressions embedded in DDL, such as CHECK constraint text,
    index expressions, server defaults and computed column expressions,
    are rendered by the statement compiler from :attr:`.sql_compiler`,
    mostly with ``literal_binds=True``.  DDL carries no bound
    parameters, so :meth:`.construct_params` returns ``None``.
    """

    is_ddl = True

    if TYPE_CHECKING:

        def __init__(
            self,
            dialect: Dialect,
            statement: ExecutableDDLElement,
            schema_translate_map: Optional[SchemaTranslateMapType] = ...,
            render_schema_translate: bool = ...,
            compile_kwargs: Mapping[str, Any] = ...,
        ): ...

    @util.ro_memoized_property
    def sql_compiler(self) -> SQLCompiler:
        """Statement compiler used for SQL expressions embedded in DDL.

        It is built from this dialect and shares this compiler's
        ``schema_translate_map``.
        """
        return self.dialect.statement_compiler(
            self.dialect, None, schema_translate_map=self.schema_translate_map
        )

    @util.memoized_property
    def type_compiler(self):
        """The dialect's type compiler instance."""
        return self.dialect.type_compiler_instance

    def construct_params(
        self,
        params: Optional[_CoreSingleExecuteParams] = None,
        extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
        escape_names: bool = True,
    ) -> Optional[_MutableCoreSingleExecuteParams]:
        """Return ``None``: compiled DDL has no bound parameters."""
        return None

    def visit_ddl(self, ddl, **kwargs):
        """Render a textual DDL construct.

        ``ddl.statement`` is %-formatted against ``ddl.context``.  When
        the target is a :class:`.Table`, a copy of the context gets
        defaults for ``table``, ``schema`` and ``fullname`` derived from
        the target.  Entries already present in the context win.
        """
        # table events can substitute table and schema name
        context = ddl.context
        if isinstance(ddl.target, schema.Table):
            # copy so the defaults below don't leak into ddl.context
            context = context.copy()

            preparer = self.preparer
            path = preparer.format_table_seq(ddl.target)
            if len(path) == 1:
                table, sch = path[0], ""
            else:
                table, sch = path[-1], path[0]

            context.setdefault("table", table)
            context.setdefault("schema", sch)
            context.setdefault("fullname", preparer.format_table(ddl.target))

        return self.sql_compiler.post_process_text(ddl.statement % context)

    def visit_create_schema(self, create, **kw):
        """Render ``CREATE SCHEMA [IF NOT EXISTS] <name>``."""
        text = "CREATE SCHEMA "
        if create.if_not_exists:
            text += "IF NOT EXISTS "
        return text + self.preparer.format_schema(create.element)

    def visit_drop_schema(self, drop, **kw):
        """Render ``DROP SCHEMA [IF EXISTS] <name> [CASCADE]``."""
        text = "DROP SCHEMA "
        if drop.if_exists:
            text += "IF EXISTS "
        text += self.preparer.format_schema(drop.element)
        if drop.cascade:
            text += " CASCADE"
        return text

    def visit_create_table(self, create, **kw):
        """Render ``CREATE TABLE`` for the table in ``create.element``.

        Column specifications come from the entries of
        ``create.columns``.  Table-level constraints come from
        :meth:`.create_table_constraints`.  Dialect text can be added
        via :meth:`.create_table_suffix` (before the column list) and
        :meth:`.post_create_table` (after it).

        :raises CompileError: re-raised with the table and column names
         prepended when a column fails to compile.
        """
        table = create.element
        preparer = self.preparer

        text = "\nCREATE "
        if table._prefixes:
            text += " ".join(table._prefixes) + " "

        text += "TABLE "
        if create.if_not_exists:
            text += "IF NOT EXISTS "

        text += preparer.format_table(table) + " "

        create_table_suffix = self.create_table_suffix(table)
        if create_table_suffix:
            text += create_table_suffix + " "

        text += "("

        separator = "\n"

        # first_pk is True only for the first primary key column; it is
        # handed to get_column_specification(), which dialects may use
        # to render the primary key inline with that column
        first_pk = False
        for create_column in create.columns:
            column = create_column.element
            try:
                processed = self.process(
                    create_column, first_pk=column.primary_key and not first_pk
                )
                # None means "don't emit this column" (e.g. system cols)
                if processed is not None:
                    text += separator
                    separator = ", \n"
                    text += "\t" + processed
                if column.primary_key:
                    first_pk = True
            except exc.CompileError as ce:
                raise exc.CompileError(
                    "(in table '%s', column '%s'): %s"
                    % (table.description, column.name, ce.args[0])
                ) from ce

        const = self.create_table_constraints(
            table,
            _include_foreign_key_constraints=create.include_foreign_key_constraints,  # noqa
        )
        if const:
            text += separator + "\t" + const

        text += "\n)%s\n\n" % self.post_create_table(table)
        return text

    def visit_create_column(self, create, first_pk=False, **kw):
        """Render one column's specification plus its column-level
        constraints.

        Returns ``None`` for ``system`` columns, which are not emitted.

        :param first_pk: passed through to
         :meth:`.get_column_specification`.
        """
        column = create.element

        if column.system:
            return None

        text = self.get_column_specification(column, first_pk=first_pk)
        const = " ".join(
            self.process(constraint) for constraint in column.constraints
        )
        if const:
            text += " " + const

        return text

    def create_table_constraints(
        self, table, _include_foreign_key_constraints=None, **kw
    ):
        """Render the table-level constraints of a CREATE TABLE.

        Entries are joined with a comma, newline and tab.  The primary
        key comes first, then the remaining constraints in
        ``table._sorted_constraints`` order.  A constraint is skipped
        when any of these holds:

        * ``_should_create_for_compiler()`` is false;
        * it is a ``use_alter`` constraint and the dialect supports
          ALTER;
        * ``_include_foreign_key_constraints`` is given and it is a
          foreign key not in that set;
        * its rendering is ``None`` or an empty string.

        :meth:`.visit_unique_constraint` and
        :meth:`.visit_primary_key_constraint` return ``""`` for
        constraints without columns.  Skipping those avoids an empty
        entry and a dangling separator.
        """
        # On some DB order is significant: visit PK first, then the
        # other constraints (engine.ReflectionTest.testbasic failed on FB2)
        constraints = []
        if table.primary_key:
            constraints.append(table.primary_key)

        all_fkcs = table.foreign_key_constraints
        if _include_foreign_key_constraints is not None:
            omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints)
        else:
            omit_fkcs = set()

        constraints.extend(
            [
                c
                for c in table._sorted_constraints
                if c is not table.primary_key and c not in omit_fkcs
            ]
        )

        return ", \n\t".join(
            p
            for p in (
                self.process(constraint)
                for constraint in constraints
                if (constraint._should_create_for_compiler(self))
                and (
                    not self.dialect.supports_alter
                    or not getattr(constraint, "use_alter", False)
                )
            )
            # drop both None and "" renderings
            if p
        )

    def visit_drop_table(self, drop, **kw):
        """Render ``DROP TABLE [IF EXISTS] <table>``."""
        text = "\nDROP TABLE "
        if drop.if_exists:
            text += "IF EXISTS "
        return text + self.preparer.format_table(drop.element)

    def visit_drop_view(self, drop, **kw):
        """Render ``DROP VIEW <name>``."""
        return "\nDROP VIEW " + self.preparer.format_table(drop.element)

    def _verify_index_table(self, index: Index) -> None:
        """Raise :class:`.CompileError` if *index* has no table."""
        if index.table is None:
            raise exc.CompileError(
                "Index '%s' is not associated with any table." % index.name
            )

    def visit_create_index(
        self, create, include_schema=False, include_table_schema=True, **kw
    ):
        """Render ``CREATE [UNIQUE] INDEX [IF NOT EXISTS] <name> ON
        <table> (<expressions>)``.

        :param include_schema: schema-qualify the index name.
        :param include_table_schema: schema-qualify the table name.
        :raises CompileError: if the index has no table or no name.
        """
        index = create.element
        self._verify_index_table(index)
        preparer = self.preparer
        text = "CREATE "
        if index.unique:
            text += "UNIQUE "
        if index.name is None:
            raise exc.CompileError(
                "CREATE INDEX requires that the index have a name"
            )

        text += "INDEX "
        if create.if_not_exists:
            text += "IF NOT EXISTS "

        text += "%s ON %s (%s)" % (
            self._prepared_index_name(index, include_schema=include_schema),
            preparer.format_table(
                index.table, use_schema=include_table_schema
            ),
            ", ".join(
                self.sql_compiler.process(
                    expr, include_table=False, literal_binds=True
                )
                for expr in index.expressions
            ),
        )
        return text

    def visit_drop_index(self, drop, **kw):
        """Render ``DROP INDEX [IF EXISTS] <schema.name>``.

        :raises CompileError: if the index has no name.
        """
        index = drop.element

        if index.name is None:
            raise exc.CompileError(
                "DROP INDEX requires that the index have a name"
            )
        text = "\nDROP INDEX "
        if drop.if_exists:
            text += "IF EXISTS "

        return text + self._prepared_index_name(index, include_schema=True)

    def _prepared_index_name(
        self, index: Index, include_schema: bool = False
    ) -> str:
        """Return the quoted index name.

        The name is prefixed with the table's effective schema when
        *include_schema* is set and a schema is present.
        """
        if index.table is not None:
            effective_schema = self.preparer.schema_for_object(index.table)
        else:
            effective_schema = None
        if include_schema and effective_schema:
            schema_name = self.preparer.quote_schema(effective_schema)
        else:
            schema_name = None

        index_name: str = self.preparer.format_index(index)

        if schema_name:
            index_name = schema_name + "." + index_name
        return index_name

    def visit_add_constraint(self, create, **kw):
        """Render ``ALTER TABLE <table> ADD <constraint>``."""
        return "ALTER TABLE %s ADD %s" % (
            self.preparer.format_table(create.element.table),
            self.process(create.element),
        )

    def visit_set_table_comment(self, create, **kw):
        """Render ``COMMENT ON TABLE <table> IS '<comment>'``."""
        return "COMMENT ON TABLE %s IS %s" % (
            self.preparer.format_table(create.element),
            self.sql_compiler.render_literal_value(
                create.element.comment, sqltypes.String()
            ),
        )

    def visit_drop_table_comment(self, drop, **kw):
        """Render ``COMMENT ON TABLE <table> IS NULL``."""
        return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
            drop.element
        )

    def visit_set_column_comment(self, create, **kw):
        """Render ``COMMENT ON COLUMN <schema.table.column> IS
        '<comment>'``."""
        return "COMMENT ON COLUMN %s IS %s" % (
            self.preparer.format_column(
                create.element, use_table=True, use_schema=True
            ),
            self.sql_compiler.render_literal_value(
                create.element.comment, sqltypes.String()
            ),
        )

    def visit_drop_column_comment(self, drop, **kw):
        """Render ``COMMENT ON COLUMN <schema.table.column> IS NULL``.

        The table is schema-qualified, as in
        :meth:`.visit_set_column_comment`.  For a table in a
        non-default schema, the comment is then removed from the same
        column it was set on.
        """
        return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column(
            drop.element, use_table=True, use_schema=True
        )

    def visit_set_constraint_comment(self, create, **kw):
        """Constraint comments are not supported by the generic compiler.

        :raises UnsupportedCompilationError: always.
        """
        raise exc.UnsupportedCompilationError(self, type(create))

    def visit_drop_constraint_comment(self, drop, **kw):
        """Constraint comments are not supported by the generic compiler.

        :raises UnsupportedCompilationError: always.
        """
        raise exc.UnsupportedCompilationError(self, type(drop))

    def get_identity_options(self, identity_options):
        """Render the option list used by both CREATE SEQUENCE and
        IDENTITY columns, e.g. ``INCREMENT BY 1 START WITH 5 NO CYCLE``.

        * Numeric options (increment, start, minvalue, maxvalue, cache)
          are emitted when not ``None``.
        * ``NO MINVALUE`` / ``NO MAXVALUE`` are emitted only when
          ``nominvalue`` / ``nomaxvalue`` are true.  An explicit
          ``False`` omits the clause instead of rendering it.
        * ``cycle`` renders ``CYCLE`` or ``NO CYCLE`` when not ``None``.

        :return: the options joined with single spaces, or ``""``.
        """
        text = []
        if identity_options.increment is not None:
            text.append("INCREMENT BY %d" % identity_options.increment)
        if identity_options.start is not None:
            text.append("START WITH %d" % identity_options.start)
        if identity_options.minvalue is not None:
            text.append("MINVALUE %d" % identity_options.minvalue)
        if identity_options.maxvalue is not None:
            text.append("MAXVALUE %d" % identity_options.maxvalue)
        if identity_options.nominvalue:
            text.append("NO MINVALUE")
        if identity_options.nomaxvalue:
            text.append("NO MAXVALUE")
        if identity_options.cache is not None:
            text.append("CACHE %d" % identity_options.cache)
        if identity_options.cycle is not None:
            text.append("CYCLE" if identity_options.cycle else "NO CYCLE")
        return " ".join(text)

    def visit_create_sequence(self, create, prefix=None, **kw):
        """Render ``CREATE SEQUENCE [IF NOT EXISTS] <name>`` followed by
        *prefix* (appended verbatim, with no separator) and the identity
        options from :meth:`.get_identity_options`."""
        text = "CREATE SEQUENCE "
        if create.if_not_exists:
            text += "IF NOT EXISTS "
        text += self.preparer.format_sequence(create.element)

        if prefix:
            text += prefix
        options = self.get_identity_options(create.element)
        if options:
            text += " " + options
        return text

    def visit_drop_sequence(self, drop, **kw):
        """Render ``DROP SEQUENCE [IF EXISTS] <name>``."""
        text = "DROP SEQUENCE "
        if drop.if_exists:
            text += "IF EXISTS "
        return text + self.preparer.format_sequence(drop.element)

    def visit_drop_constraint(self, drop, **kw):
        """Render ``ALTER TABLE <table> DROP CONSTRAINT [IF EXISTS]
        <name> [CASCADE]``.

        :raises CompileError: if the constraint has no name, or its name
         formats to ``None``.
        """
        constraint = drop.element
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
        else:
            formatted_name = None

        if formatted_name is None:
            raise exc.CompileError(
                "Can't emit DROP CONSTRAINT for constraint %r; "
                "it has no name" % drop.element
            )
        return "ALTER TABLE %s DROP CONSTRAINT %s%s%s" % (
            self.preparer.format_table(drop.element.table),
            "IF EXISTS " if drop.if_exists else "",
            formatted_name,
            " CASCADE" if drop.cascade else "",
        )

    def get_column_specification(self, column, **kwargs):
        """Render ``<name> <type> [DEFAULT ...] [<computed>] [<identity>]
        [NOT NULL]`` for one column.

        The identity clause is rendered only when the dialect supports
        identity columns.  In that case NOT NULL is not rendered for a
        column that has an identity.  Extra keyword arguments (such as
        ``first_pk``) are accepted and ignored.
        """
        colspec = (
            self.preparer.format_column(column)
            + " "
            + self.dialect.type_compiler_instance.process(
                column.type, type_expression=column
            )
        )
        default = self.get_column_default_string(column)
        if default is not None:
            colspec += " DEFAULT " + default

        if column.computed is not None:
            colspec += " " + self.process(column.computed)

        if (
            column.identity is not None
            and self.dialect.supports_identity_columns
        ):
            colspec += " " + self.process(column.identity)

        if not column.nullable and (
            not column.identity or not self.dialect.supports_identity_columns
        ):
            colspec += " NOT NULL"
        return colspec

    def create_table_suffix(self, table):
        """Hook for text placed between the table name and the column
        list of CREATE TABLE.  The base implementation returns ``""``."""
        return ""

    def post_create_table(self, table):
        """Hook for text placed after the closing parenthesis of CREATE
        TABLE.  The base implementation returns ``""``."""
        return ""

    def get_column_default_string(self, column: Column[Any]) -> Optional[str]:
        """Render the column's server default when it is a
        :class:`.DefaultClause`; otherwise return ``None``."""
        if isinstance(column.server_default, schema.DefaultClause):
            return self.render_default_string(column.server_default.arg)
        else:
            return None

    def render_default_string(self, default: Union[Visitable, str]) -> str:
        """Render a server default value.

        Plain strings become string literals.  Anything else is compiled
        as a SQL expression with literal (inline) bound values.
        """
        if isinstance(default, str):
            return self.sql_compiler.render_literal_value(
                default, sqltypes.STRINGTYPE
            )
        else:
            return self.sql_compiler.process(default, literal_binds=True)

    def visit_table_or_column_check_constraint(self, constraint, **kw):
        """Dispatch to the column-level or table-level CHECK renderer
        based on ``constraint.is_column_level``."""
        if constraint.is_column_level:
            return self.visit_column_check_constraint(constraint)
        else:
            return self.visit_check_constraint(constraint)

    def visit_check_constraint(self, constraint, **kw):
        """Render a table-level ``[CONSTRAINT name] CHECK (<expr>)`` with
        any deferrability clause."""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "CHECK (%s)" % self.sql_compiler.process(
            constraint.sqltext, include_table=False, literal_binds=True
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_column_check_constraint(self, constraint, **kw):
        """Render a column-level ``[CONSTRAINT name] CHECK (<expr>)``
        with any deferrability clause."""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "CHECK (%s)" % self.sql_compiler.process(
            constraint.sqltext, include_table=False, literal_binds=True
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_primary_key_constraint(
        self, constraint: PrimaryKeyConstraint, **kw: Any
    ) -> str:
        """Render ``[CONSTRAINT name] PRIMARY KEY (<cols>)``.

        Returns ``""`` when the constraint has no columns.  For an
        implicitly generated primary key, the columns are listed with
        ``columns_autoinc_first`` ordering.
        """
        if len(constraint) == 0:
            return ""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "PRIMARY KEY "
        text += "(%s)" % ", ".join(
            self.preparer.quote(c.name)
            for c in (
                constraint.columns_autoinc_first
                if constraint._implicit_generated
                else constraint.columns
            )
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_foreign_key_constraint(self, constraint, **kw):
        """Render ``[CONSTRAINT name] FOREIGN KEY(<cols>) REFERENCES
        <remote> (<cols>)`` plus MATCH, ON DELETE / ON UPDATE and
        deferrability clauses.

        The remote table is taken from the first element's referenced
        column.
        """
        preparer = self.preparer
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        remote_table = list(constraint.elements)[0].column.table
        text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
            ", ".join(
                preparer.quote(f.parent.name) for f in constraint.elements
            ),
            self.define_constraint_remote_table(
                constraint, remote_table, preparer
            ),
            ", ".join(
                preparer.quote(f.column.name) for f in constraint.elements
            ),
        )
        text += self.define_constraint_match(constraint)
        text += self.define_constraint_cascades(constraint)
        text += self.define_constraint_deferrability(constraint)
        return text

    def define_constraint_remote_table(self, constraint, table, preparer):
        """Format the remote table clause of a CREATE CONSTRAINT clause."""

        return preparer.format_table(table)

    def visit_unique_constraint(
        self, constraint: UniqueConstraint, **kw: Any
    ) -> str:
        """Render ``[CONSTRAINT name] UNIQUE [<distinct>](<cols>)``.

        Returns ``""`` when the constraint has no columns.
        """
        if len(constraint) == 0:
            return ""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "UNIQUE %s(%s)" % (
            self.define_unique_constraint_distinct(constraint, **kw),
            ", ".join(self.preparer.quote(c.name) for c in constraint),
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def define_unique_constraint_distinct(
        self, constraint: UniqueConstraint, **kw: Any
    ) -> str:
        """Hook for text placed between ``UNIQUE`` and the column list.
        The base implementation returns ``""``."""
        return ""

    def define_constraint_cascades(
        self, constraint: ForeignKeyConstraint
    ) -> str:
        """Render the ON DELETE / ON UPDATE clauses that are set."""
        text = ""
        if constraint.ondelete is not None:
            text += self.define_constraint_ondelete_cascade(constraint)

        if constraint.onupdate is not None:
            text += self.define_constraint_onupdate_cascade(constraint)
        return text

    def define_constraint_ondelete_cascade(
        self, constraint: ForeignKeyConstraint
    ) -> str:
        """Render `` ON DELETE <action>``.

        The action is validated against the ``FK_ON_DELETE`` pattern.
        """
        return " ON DELETE %s" % self.preparer.validate_sql_phrase(
            constraint.ondelete, FK_ON_DELETE
        )

    def define_constraint_onupdate_cascade(
        self, constraint: ForeignKeyConstraint
    ) -> str:
        """Render `` ON UPDATE <action>``.

        The action is validated against the ``FK_ON_UPDATE`` pattern.
        """
        return " ON UPDATE %s" % self.preparer.validate_sql_phrase(
            constraint.onupdate, FK_ON_UPDATE
        )

    def define_constraint_deferrability(self, constraint: Constraint) -> str:
        """Render `` [NOT] DEFERRABLE`` and `` INITIALLY <mode>`` when
        set.

        The INITIALLY value is validated against ``FK_INITIALLY``.
        """
        text = ""
        if constraint.deferrable is not None:
            if constraint.deferrable:
                text += " DEFERRABLE"
            else:
                text += " NOT DEFERRABLE"
        if constraint.initially is not None:
            text += " INITIALLY %s" % self.preparer.validate_sql_phrase(
                constraint.initially, FK_INITIALLY
            )
        return text

    def define_constraint_match(self, constraint):
        """Render `` MATCH <type>`` when ``constraint.match`` is set."""
        text = ""
        if constraint.match is not None:
            # NOTE(review): unlike ON DELETE / ON UPDATE / INITIALLY,
            # the MATCH value is inserted without validate_sql_phrase
            text += " MATCH %s" % constraint.match
        return text

    def visit_computed_column(self, generated, **kw):
        """Render ``GENERATED ALWAYS AS (<expr>) [STORED | VIRTUAL]``.

        STORED / VIRTUAL follow ``persisted`` being True / False.  The
        keyword is omitted when ``persisted`` is ``None``.
        """
        text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
            generated.sqltext, include_table=False, literal_binds=True
        )
        if generated.persisted is True:
            text += " STORED"
        elif generated.persisted is False:
            text += " VIRTUAL"
        return text

    def visit_identity_column(self, identity, **kw):
        """Render ``GENERATED {ALWAYS | BY DEFAULT} AS IDENTITY
        [(<options>)]``."""
        text = "GENERATED %s AS IDENTITY" % (
            "ALWAYS" if identity.always else "BY DEFAULT",
        )
        options = self.get_identity_options(identity)
        if options:
            text += " (%s)" % options
        return text
7234
7235
7236class GenericTypeCompiler(TypeCompiler):
7237 def visit_FLOAT(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
7238 return "FLOAT"
7239
7240 def visit_DOUBLE(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
7241 return "DOUBLE"
7242
7243 def visit_DOUBLE_PRECISION(
7244 self, type_: sqltypes.DOUBLE_PRECISION[Any], **kw: Any
7245 ) -> str:
7246 return "DOUBLE PRECISION"
7247
7248 def visit_REAL(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
7249 return "REAL"
7250
7251 def visit_NUMERIC(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
7252 if type_.precision is None:
7253 return "NUMERIC"
7254 elif type_.scale is None:
7255 return "NUMERIC(%(precision)s)" % {"precision": type_.precision}
7256 else:
7257 return "NUMERIC(%(precision)s, %(scale)s)" % {
7258 "precision": type_.precision,
7259 "scale": type_.scale,
7260 }
7261
7262 def visit_DECIMAL(self, type_: sqltypes.DECIMAL[Any], **kw: Any) -> str:
7263 if type_.precision is None:
7264 return "DECIMAL"
7265 elif type_.scale is None:
7266 return "DECIMAL(%(precision)s)" % {"precision": type_.precision}
7267 else:
7268 return "DECIMAL(%(precision)s, %(scale)s)" % {
7269 "precision": type_.precision,
7270 "scale": type_.scale,
7271 }
7272
7273 def visit_INTEGER(self, type_: sqltypes.Integer, **kw: Any) -> str:
7274 return "INTEGER"
7275
7276 def visit_SMALLINT(self, type_: sqltypes.SmallInteger, **kw: Any) -> str:
7277 return "SMALLINT"
7278
7279 def visit_BIGINT(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
7280 return "BIGINT"
7281
7282 def visit_TIMESTAMP(self, type_: sqltypes.TIMESTAMP, **kw: Any) -> str:
7283 return "TIMESTAMP"
7284
7285 def visit_DATETIME(self, type_: sqltypes.DateTime, **kw: Any) -> str:
7286 return "DATETIME"
7287
7288 def visit_DATE(self, type_: sqltypes.Date, **kw: Any) -> str:
7289 return "DATE"
7290
7291 def visit_TIME(self, type_: sqltypes.Time, **kw: Any) -> str:
7292 return "TIME"
7293
7294 def visit_CLOB(self, type_: sqltypes.CLOB, **kw: Any) -> str:
7295 return "CLOB"
7296
7297 def visit_NCLOB(self, type_: sqltypes.Text, **kw: Any) -> str:
7298 return "NCLOB"
7299
7300 def _render_string_type(
7301 self, name: str, length: Optional[int], collation: Optional[str]
7302 ) -> str:
7303 text = name
7304 if length:
7305 text += f"({length})"
7306 if collation:
7307 text += f' COLLATE "{collation}"'
7308 return text
7309
7310 def visit_CHAR(self, type_: sqltypes.CHAR, **kw: Any) -> str:
7311 return self._render_string_type("CHAR", type_.length, type_.collation)
7312
7313 def visit_NCHAR(self, type_: sqltypes.NCHAR, **kw: Any) -> str:
7314 return self._render_string_type("NCHAR", type_.length, type_.collation)
7315
7316 def visit_VARCHAR(self, type_: sqltypes.String, **kw: Any) -> str:
7317 return self._render_string_type(
7318 "VARCHAR", type_.length, type_.collation
7319 )
7320
7321 def visit_NVARCHAR(self, type_: sqltypes.NVARCHAR, **kw: Any) -> str:
7322 return self._render_string_type(
7323 "NVARCHAR", type_.length, type_.collation
7324 )
7325
7326 def visit_TEXT(self, type_: sqltypes.Text, **kw: Any) -> str:
7327 return self._render_string_type("TEXT", type_.length, type_.collation)
7328
7329 def visit_UUID(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
7330 return "UUID"
7331
7332 def visit_BLOB(self, type_: sqltypes.LargeBinary, **kw: Any) -> str:
7333 return "BLOB"
7334
7335 def visit_BINARY(self, type_: sqltypes.BINARY, **kw: Any) -> str:
7336 return "BINARY" + (type_.length and "(%d)" % type_.length or "")
7337
7338 def visit_VARBINARY(self, type_: sqltypes.VARBINARY, **kw: Any) -> str:
7339 return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
7340
7341 def visit_BOOLEAN(self, type_: sqltypes.Boolean, **kw: Any) -> str:
7342 return "BOOLEAN"
7343
7344 def visit_uuid(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
7345 if not type_.native_uuid or not self.dialect.supports_native_uuid:
7346 return self._render_string_type("CHAR", length=32, collation=None)
7347 else:
7348 return self.visit_UUID(type_, **kw)
7349
7350 def visit_large_binary(
7351 self, type_: sqltypes.LargeBinary, **kw: Any
7352 ) -> str:
7353 return self.visit_BLOB(type_, **kw)
7354
7355 def visit_boolean(self, type_: sqltypes.Boolean, **kw: Any) -> str:
7356 return self.visit_BOOLEAN(type_, **kw)
7357
7358 def visit_time(self, type_: sqltypes.Time, **kw: Any) -> str:
7359 return self.visit_TIME(type_, **kw)
7360
7361 def visit_datetime(self, type_: sqltypes.DateTime, **kw: Any) -> str:
7362 return self.visit_DATETIME(type_, **kw)
7363
7364 def visit_date(self, type_: sqltypes.Date, **kw: Any) -> str:
7365 return self.visit_DATE(type_, **kw)
7366
7367 def visit_big_integer(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
7368 return self.visit_BIGINT(type_, **kw)
7369
7370 def visit_small_integer(
7371 self, type_: sqltypes.SmallInteger, **kw: Any
7372 ) -> str:
7373 return self.visit_SMALLINT(type_, **kw)
7374
7375 def visit_integer(self, type_: sqltypes.Integer, **kw: Any) -> str:
7376 return self.visit_INTEGER(type_, **kw)
7377
7378 def visit_real(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
7379 return self.visit_REAL(type_, **kw)
7380
7381 def visit_float(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
7382 return self.visit_FLOAT(type_, **kw)
7383
7384 def visit_double(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
7385 return self.visit_DOUBLE(type_, **kw)
7386
7387 def visit_numeric(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
7388 return self.visit_NUMERIC(type_, **kw)
7389
7390 def visit_string(self, type_: sqltypes.String, **kw: Any) -> str:
7391 return self.visit_VARCHAR(type_, **kw)
7392
7393 def visit_unicode(self, type_: sqltypes.Unicode, **kw: Any) -> str:
7394 return self.visit_VARCHAR(type_, **kw)
7395
7396 def visit_text(self, type_: sqltypes.Text, **kw: Any) -> str:
7397 return self.visit_TEXT(type_, **kw)
7398
7399 def visit_unicode_text(
7400 self, type_: sqltypes.UnicodeText, **kw: Any
7401 ) -> str:
7402 return self.visit_TEXT(type_, **kw)
7403
7404 def visit_enum(self, type_: sqltypes.Enum, **kw: Any) -> str:
7405 return self.visit_VARCHAR(type_, **kw)
7406
7407 def visit_null(self, type_, **kw):
7408 raise exc.CompileError(
7409 "Can't generate DDL for %r; "
7410 "did you forget to specify a "
7411 "type on this Column?" % type_
7412 )
7413
7414 def visit_type_decorator(
7415 self, type_: TypeDecorator[Any], **kw: Any
7416 ) -> str:
7417 return self.process(type_.type_engine(self.dialect), **kw)
7418
7419 def visit_user_defined(
7420 self, type_: UserDefinedType[Any], **kw: Any
7421 ) -> str:
7422 return type_.get_col_spec(**kw)
7423
7424
class StrSQLTypeCompiler(GenericTypeCompiler):
    """Type compiler used for stringifying types without raising.

    Any ``visit_*`` method not defined by this class or its bases resolves
    to a best-effort renderer that emits the class name (for all-uppercase
    class names) or ``repr()`` of the type.
    """

    def process(self, type_, **kw):
        try:
            dispatch = type_._compiler_dispatch
        except AttributeError:
            # not a visitable type; fall back to a best-effort rendering
            return self._visit_unknown(type_, **kw)
        return dispatch(self, **kw)

    def __getattr__(self, key):
        # Only reached when normal attribute lookup fails: route unknown
        # visitor names to the catch-all, reject everything else.
        if not key.startswith("visit_"):
            raise AttributeError(key)
        return self._visit_unknown

    def _visit_unknown(self, type_, **kw):
        class_name = type_.__class__.__name__
        # all-uppercase class names already read like SQL type keywords
        if class_name.upper() == class_name:
            return class_name
        return repr(type_)

    def visit_null(self, type_, **kw):
        return "NULL"

    def visit_user_defined(self, type_, **kw):
        try:
            col_spec = type_.get_col_spec
        except AttributeError:
            return repr(type_)
        return col_spec(**kw)
7456
7457
7458class _SchemaForObjectCallable(Protocol):
7459 def __call__(self, __obj: Any) -> str: ...
7460
7461
7462class _BindNameForColProtocol(Protocol):
7463 def __call__(self, col: ColumnClause[Any]) -> str: ...
7464
7465
class IdentifierPreparer:
    """Handle quoting and case-folding of identifiers based on options.

    An instance is bound to one dialect (passed to ``__init__``) and
    renders table, column, schema, sequence, constraint/index and label
    names, quoting them only when required or explicitly requested.
    """

    # Identifiers whose lower-cased form is in this collection are always
    # quoted (checked in _requires_quotes).
    reserved_words = RESERVED_WORDS

    # Pattern an identifier is .match()-ed against; a failed match forces
    # quoting.
    legal_characters = LEGAL_CHARACTERS

    # Identifiers whose first character is in this collection are quoted.
    illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS

    # Character that opens a delimited identifier; set in __init__.
    initial_quote: str

    # Character that closes a delimited identifier; set in __init__.
    final_quote: str

    # Cache of identifier -> rendered (quoted or bare) form, filled by
    # quote() for identifiers that carry no explicit quote flag.
    _strings: MutableMapping[str, str]

    schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema")
    """Return the .schema attribute for an object.

    For the default IdentifierPreparer, the schema for an object is always
    the value of the ".schema" attribute. If the preparer is replaced
    with one that has a non-empty schema_translate_map, the value of the
    ".schema" attribute is rendered as a symbol that will be converted to a
    real schema name from the mapping post-compile.

    """

    # Set by _with_schema_translate() when the translate map has a None
    # key; _render_schema_translates() refuses a None key otherwise.
    _includes_none_schema_translate: bool = False
7493
7494 def __init__(
7495 self,
7496 dialect: Dialect,
7497 initial_quote: str = '"',
7498 final_quote: Optional[str] = None,
7499 escape_quote: str = '"',
7500 quote_case_sensitive_collations: bool = True,
7501 omit_schema: bool = False,
7502 ):
7503 """Construct a new ``IdentifierPreparer`` object.
7504
7505 initial_quote
7506 Character that begins a delimited identifier.
7507
7508 final_quote
7509 Character that ends a delimited identifier. Defaults to
7510 `initial_quote`.
7511
7512 omit_schema
7513 Prevent prepending schema name. Useful for databases that do
7514 not support schemae.
7515 """
7516
7517 self.dialect = dialect
7518 self.initial_quote = initial_quote
7519 self.final_quote = final_quote or self.initial_quote
7520 self.escape_quote = escape_quote
7521 self.escape_to_quote = self.escape_quote * 2
7522 self.omit_schema = omit_schema
7523 self.quote_case_sensitive_collations = quote_case_sensitive_collations
7524 self._strings = {}
7525 self._double_percents = self.dialect.paramstyle in (
7526 "format",
7527 "pyformat",
7528 )
7529
7530 def _with_schema_translate(self, schema_translate_map):
7531 prep = self.__class__.__new__(self.__class__)
7532 prep.__dict__.update(self.__dict__)
7533
7534 includes_none = None in schema_translate_map
7535
7536 def symbol_getter(obj):
7537 name = obj.schema
7538 if obj._use_schema_map and (name is not None or includes_none):
7539 if name is not None and ("[" in name or "]" in name):
7540 raise exc.CompileError(
7541 "Square bracket characters ([]) not supported "
7542 "in schema translate name '%s'" % name
7543 )
7544 return quoted_name(
7545 "__[SCHEMA_%s]" % (name or "_none"), quote=False
7546 )
7547 else:
7548 return obj.schema
7549
7550 prep.schema_for_object = symbol_getter
7551 prep._includes_none_schema_translate = includes_none
7552 return prep
7553
7554 def _render_schema_translates(
7555 self, statement: str, schema_translate_map: SchemaTranslateMapType
7556 ) -> str:
7557 d = schema_translate_map
7558 if None in d:
7559 if not self._includes_none_schema_translate:
7560 raise exc.InvalidRequestError(
7561 "schema translate map which previously did not have "
7562 "`None` present as a key now has `None` present; compiled "
7563 "statement may lack adequate placeholders. Please use "
7564 "consistent keys in successive "
7565 "schema_translate_map dictionaries."
7566 )
7567
7568 d["_none"] = d[None] # type: ignore[index]
7569
7570 def replace(m):
7571 name = m.group(2)
7572 if name in d:
7573 effective_schema = d[name]
7574 else:
7575 if name in (None, "_none"):
7576 raise exc.InvalidRequestError(
7577 "schema translate map which previously had `None` "
7578 "present as a key now no longer has it present; don't "
7579 "know how to apply schema for compiled statement. "
7580 "Please use consistent keys in successive "
7581 "schema_translate_map dictionaries."
7582 )
7583 effective_schema = name
7584
7585 if not effective_schema:
7586 effective_schema = self.dialect.default_schema_name
7587 if not effective_schema:
7588 # TODO: no coverage here
7589 raise exc.CompileError(
7590 "Dialect has no default schema name; can't "
7591 "use None as dynamic schema target."
7592 )
7593 return self.quote_schema(effective_schema)
7594
7595 return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement)
7596
7597 def _escape_identifier(self, value: str) -> str:
7598 """Escape an identifier.
7599
7600 Subclasses should override this to provide database-dependent
7601 escaping behavior.
7602 """
7603
7604 value = value.replace(self.escape_quote, self.escape_to_quote)
7605 if self._double_percents:
7606 value = value.replace("%", "%%")
7607 return value
7608
7609 def _unescape_identifier(self, value: str) -> str:
7610 """Canonicalize an escaped identifier.
7611
7612 Subclasses should override this to provide database-dependent
7613 unescaping behavior that reverses _escape_identifier.
7614 """
7615
7616 return value.replace(self.escape_to_quote, self.escape_quote)
7617
7618 def validate_sql_phrase(self, element, reg):
7619 """keyword sequence filter.
7620
7621 a filter for elements that are intended to represent keyword sequences,
7622 such as "INITIALLY", "INITIALLY DEFERRED", etc. no special characters
7623 should be present.
7624
7625 .. versionadded:: 1.3
7626
7627 """
7628
7629 if element is not None and not reg.match(element):
7630 raise exc.CompileError(
7631 "Unexpected SQL phrase: %r (matching against %r)"
7632 % (element, reg.pattern)
7633 )
7634 return element
7635
7636 def quote_identifier(self, value: str) -> str:
7637 """Quote an identifier.
7638
7639 Subclasses should override this to provide database-dependent
7640 quoting behavior.
7641 """
7642
7643 return (
7644 self.initial_quote
7645 + self._escape_identifier(value)
7646 + self.final_quote
7647 )
7648
7649 def _requires_quotes(self, value: str) -> bool:
7650 """Return True if the given identifier requires quoting."""
7651 lc_value = value.lower()
7652 return (
7653 lc_value in self.reserved_words
7654 or value[0] in self.illegal_initial_characters
7655 or not self.legal_characters.match(str(value))
7656 or (lc_value != value)
7657 )
7658
7659 def _requires_quotes_illegal_chars(self, value):
7660 """Return True if the given identifier requires quoting, but
7661 not taking case convention into account."""
7662 return not self.legal_characters.match(str(value))
7663
7664 def quote_schema(self, schema: str, force: Any = None) -> str:
7665 """Conditionally quote a schema name.
7666
7667
7668 The name is quoted if it is a reserved word, contains quote-necessary
7669 characters, or is an instance of :class:`.quoted_name` which includes
7670 ``quote`` set to ``True``.
7671
7672 Subclasses can override this to provide database-dependent
7673 quoting behavior for schema names.
7674
7675 :param schema: string schema name
7676 :param force: unused
7677
7678 .. deprecated:: 0.9
7679
7680 The :paramref:`.IdentifierPreparer.quote_schema.force`
7681 parameter is deprecated and will be removed in a future
7682 release. This flag has no effect on the behavior of the
7683 :meth:`.IdentifierPreparer.quote` method; please refer to
7684 :class:`.quoted_name`.
7685
7686 """
7687 if force is not None:
7688 # not using the util.deprecated_params() decorator in this
7689 # case because of the additional function call overhead on this
7690 # very performance-critical spot.
7691 util.warn_deprecated(
7692 "The IdentifierPreparer.quote_schema.force parameter is "
7693 "deprecated and will be removed in a future release. This "
7694 "flag has no effect on the behavior of the "
7695 "IdentifierPreparer.quote method; please refer to "
7696 "quoted_name().",
7697 # deprecated 0.9. warning from 1.3
7698 version="0.9",
7699 )
7700
7701 return self.quote(schema)
7702
7703 def quote(self, ident: str, force: Any = None) -> str:
7704 """Conditionally quote an identifier.
7705
7706 The identifier is quoted if it is a reserved word, contains
7707 quote-necessary characters, or is an instance of
7708 :class:`.quoted_name` which includes ``quote`` set to ``True``.
7709
7710 Subclasses can override this to provide database-dependent
7711 quoting behavior for identifier names.
7712
7713 :param ident: string identifier
7714 :param force: unused
7715
7716 .. deprecated:: 0.9
7717
7718 The :paramref:`.IdentifierPreparer.quote.force`
7719 parameter is deprecated and will be removed in a future
7720 release. This flag has no effect on the behavior of the
7721 :meth:`.IdentifierPreparer.quote` method; please refer to
7722 :class:`.quoted_name`.
7723
7724 """
7725 if force is not None:
7726 # not using the util.deprecated_params() decorator in this
7727 # case because of the additional function call overhead on this
7728 # very performance-critical spot.
7729 util.warn_deprecated(
7730 "The IdentifierPreparer.quote.force parameter is "
7731 "deprecated and will be removed in a future release. This "
7732 "flag has no effect on the behavior of the "
7733 "IdentifierPreparer.quote method; please refer to "
7734 "quoted_name().",
7735 # deprecated 0.9. warning from 1.3
7736 version="0.9",
7737 )
7738
7739 force = getattr(ident, "quote", None)
7740
7741 if force is None:
7742 if ident in self._strings:
7743 return self._strings[ident]
7744 else:
7745 if self._requires_quotes(ident):
7746 self._strings[ident] = self.quote_identifier(ident)
7747 else:
7748 self._strings[ident] = ident
7749 return self._strings[ident]
7750 elif force:
7751 return self.quote_identifier(ident)
7752 else:
7753 return ident
7754
7755 def format_collation(self, collation_name):
7756 if self.quote_case_sensitive_collations:
7757 return self.quote(collation_name)
7758 else:
7759 return collation_name
7760
7761 def format_sequence(
7762 self, sequence: schema.Sequence, use_schema: bool = True
7763 ) -> str:
7764 name = self.quote(sequence.name)
7765
7766 effective_schema = self.schema_for_object(sequence)
7767
7768 if (
7769 not self.omit_schema
7770 and use_schema
7771 and effective_schema is not None
7772 ):
7773 name = self.quote_schema(effective_schema) + "." + name
7774 return name
7775
7776 def format_label(
7777 self, label: Label[Any], name: Optional[str] = None
7778 ) -> str:
7779 return self.quote(name or label.name)
7780
7781 def format_alias(
7782 self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None
7783 ) -> str:
7784 if name is None:
7785 assert alias is not None
7786 return self.quote(alias.name)
7787 else:
7788 return self.quote(name)
7789
7790 def format_savepoint(self, savepoint, name=None):
7791 # Running the savepoint name through quoting is unnecessary
7792 # for all known dialects. This is here to support potential
7793 # third party use cases
7794 ident = name or savepoint.ident
7795 if self._requires_quotes(ident):
7796 ident = self.quote_identifier(ident)
7797 return ident
7798
7799 @util.preload_module("sqlalchemy.sql.naming")
7800 def format_constraint(
7801 self, constraint: Union[Constraint, Index], _alembic_quote: bool = True
7802 ) -> Optional[str]:
7803 naming = util.preloaded.sql_naming
7804
7805 if constraint.name is _NONE_NAME:
7806 name = naming._constraint_name_for_table(
7807 constraint, constraint.table
7808 )
7809
7810 if name is None:
7811 return None
7812 else:
7813 name = constraint.name
7814
7815 assert name is not None
7816 if constraint.__visit_name__ == "index":
7817 return self.truncate_and_render_index_name(
7818 name, _alembic_quote=_alembic_quote
7819 )
7820 else:
7821 return self.truncate_and_render_constraint_name(
7822 name, _alembic_quote=_alembic_quote
7823 )
7824
7825 def truncate_and_render_index_name(
7826 self, name: str, _alembic_quote: bool = True
7827 ) -> str:
7828 # calculate these at format time so that ad-hoc changes
7829 # to dialect.max_identifier_length etc. can be reflected
7830 # as IdentifierPreparer is long lived
7831 max_ = (
7832 self.dialect.max_index_name_length
7833 or self.dialect.max_identifier_length
7834 )
7835 return self._truncate_and_render_maxlen_name(
7836 name, max_, _alembic_quote
7837 )
7838
7839 def truncate_and_render_constraint_name(
7840 self, name: str, _alembic_quote: bool = True
7841 ) -> str:
7842 # calculate these at format time so that ad-hoc changes
7843 # to dialect.max_identifier_length etc. can be reflected
7844 # as IdentifierPreparer is long lived
7845 max_ = (
7846 self.dialect.max_constraint_name_length
7847 or self.dialect.max_identifier_length
7848 )
7849 return self._truncate_and_render_maxlen_name(
7850 name, max_, _alembic_quote
7851 )
7852
7853 def _truncate_and_render_maxlen_name(
7854 self, name: str, max_: int, _alembic_quote: bool
7855 ) -> str:
7856 if isinstance(name, elements._truncated_label):
7857 if len(name) > max_:
7858 name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
7859 else:
7860 self.dialect.validate_identifier(name)
7861
7862 if not _alembic_quote:
7863 return name
7864 else:
7865 return self.quote(name)
7866
7867 def format_index(self, index: Index) -> str:
7868 name = self.format_constraint(index)
7869 assert name is not None
7870 return name
7871
7872 def format_table(
7873 self,
7874 table: FromClause,
7875 use_schema: bool = True,
7876 name: Optional[str] = None,
7877 ) -> str:
7878 """Prepare a quoted table and schema name."""
7879 if name is None:
7880 if TYPE_CHECKING:
7881 assert isinstance(table, NamedFromClause)
7882 name = table.name
7883
7884 result = self.quote(name)
7885
7886 effective_schema = self.schema_for_object(table)
7887
7888 if not self.omit_schema and use_schema and effective_schema:
7889 result = self.quote_schema(effective_schema) + "." + result
7890 return result
7891
7892 def format_schema(self, name):
7893 """Prepare a quoted schema name."""
7894
7895 return self.quote(name)
7896
7897 def format_label_name(
7898 self,
7899 name,
7900 anon_map=None,
7901 ):
7902 """Prepare a quoted column name."""
7903
7904 if anon_map is not None and isinstance(
7905 name, elements._truncated_label
7906 ):
7907 name = name.apply_map(anon_map)
7908
7909 return self.quote(name)
7910
7911 def format_column(
7912 self,
7913 column: ColumnElement[Any],
7914 use_table: bool = False,
7915 name: Optional[str] = None,
7916 table_name: Optional[str] = None,
7917 use_schema: bool = False,
7918 anon_map: Optional[Mapping[str, Any]] = None,
7919 ) -> str:
7920 """Prepare a quoted column name."""
7921
7922 if name is None:
7923 name = column.name
7924 assert name is not None
7925
7926 if anon_map is not None and isinstance(
7927 name, elements._truncated_label
7928 ):
7929 name = name.apply_map(anon_map)
7930
7931 if not getattr(column, "is_literal", False):
7932 if use_table:
7933 return (
7934 self.format_table(
7935 column.table, use_schema=use_schema, name=table_name
7936 )
7937 + "."
7938 + self.quote(name)
7939 )
7940 else:
7941 return self.quote(name)
7942 else:
7943 # literal textual elements get stuck into ColumnClause a lot,
7944 # which shouldn't get quoted
7945
7946 if use_table:
7947 return (
7948 self.format_table(
7949 column.table, use_schema=use_schema, name=table_name
7950 )
7951 + "."
7952 + name
7953 )
7954 else:
7955 return name
7956
7957 def format_table_seq(self, table, use_schema=True):
7958 """Format table name and schema as a tuple."""
7959
7960 # Dialects with more levels in their fully qualified references
7961 # ('database', 'owner', etc.) could override this and return
7962 # a longer sequence.
7963
7964 effective_schema = self.schema_for_object(table)
7965
7966 if not self.omit_schema and use_schema and effective_schema:
7967 return (
7968 self.quote_schema(effective_schema),
7969 self.format_table(table, use_schema=False),
7970 )
7971 else:
7972 return (self.format_table(table, use_schema=False),)
7973
    @util.memoized_property
    def _r_identifiers(self):
        """Compiled regex used by ``unformat_identifiers``.

        Built from this preparer's quote characters and cached via
        ``util.memoized_property``.  Each match covers one dot-separated
        component: group 1 captures the text between ``initial_quote`` and
        ``final_quote`` (escaped final quotes allowed inside), group 2
        captures an unquoted run of non-dot characters.  A component must
        be followed by ``.`` or end of string; the ``.`` itself is never
        consumed, so ``findall`` yields one ``(quoted, unquoted)`` pair per
        component.
        """
        # regex-escape the opening quote, the closing quote, and the
        # escaped form of the closing quote as _escape_identifier renders it
        initial, final, escaped_final = (
            re.escape(s)
            for s in (
                self.initial_quote,
                self.final_quote,
                self._escape_identifier(self.final_quote),
            )
        )
        r = re.compile(
            r"(?:"
            r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s"
            r"|([^\.]+))(?=\.|$))+"
            % {"initial": initial, "final": final, "escaped": escaped_final}
        )
        return r
7991
7992 def unformat_identifiers(self, identifiers: str) -> Sequence[str]:
7993 """Unpack 'schema.table.column'-like strings into components."""
7994
7995 r = self._r_identifiers
7996 return [
7997 self._unescape_identifier(i)
7998 for i in [a or b for a, b in r.findall(identifiers)]
7999 ]