1# sql/compiler.py
2# Copyright (C) 2005-2025 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: https://www.opensource.org/licenses/mit-license.php
7# mypy: allow-untyped-defs, allow-untyped-calls
8
9"""Base SQL and DDL compiler implementations.
10
11Classes provided include:
12
13:class:`.compiler.SQLCompiler` - renders SQL
14strings
15
16:class:`.compiler.DDLCompiler` - renders DDL
17(data definition language) strings
18
19:class:`.compiler.GenericTypeCompiler` - renders
20type specification strings.
21
22To generate user-defined SQL strings, see
23:doc:`/ext/compiler`.
24
25"""
26from __future__ import annotations
27
28import collections
29import collections.abc as collections_abc
30import contextlib
31from enum import IntEnum
32import functools
33import itertools
34import operator
35import re
36from time import perf_counter
37import typing
38from typing import Any
39from typing import Callable
40from typing import cast
41from typing import ClassVar
42from typing import Dict
43from typing import FrozenSet
44from typing import Iterable
45from typing import Iterator
46from typing import List
47from typing import Mapping
48from typing import MutableMapping
49from typing import NamedTuple
50from typing import NoReturn
51from typing import Optional
52from typing import Pattern
53from typing import Sequence
54from typing import Set
55from typing import Tuple
56from typing import Type
57from typing import TYPE_CHECKING
58from typing import Union
59
60from . import base
61from . import coercions
62from . import crud
63from . import elements
64from . import functions
65from . import operators
66from . import roles
67from . import schema
68from . import selectable
69from . import sqltypes
70from . import util as sql_util
71from ._typing import is_column_element
72from ._typing import is_dml
73from .base import _de_clone
74from .base import _from_objects
75from .base import _NONE_NAME
76from .base import _SentinelDefaultCharacterization
77from .base import NO_ARG
78from .elements import quoted_name
79from .sqltypes import TupleType
80from .visitors import prefix_anon_map
81from .. import exc
82from .. import util
83from ..util import FastIntFlag
84from ..util.typing import Literal
85from ..util.typing import Protocol
86from ..util.typing import Self
87from ..util.typing import TypedDict
88
89if typing.TYPE_CHECKING:
90 from .annotation import _AnnotationDict
91 from .base import _AmbiguousTableNameMap
92 from .base import CompileState
93 from .base import Executable
94 from .cache_key import CacheKey
95 from .ddl import ExecutableDDLElement
96 from .dml import Insert
97 from .dml import Update
98 from .dml import UpdateBase
99 from .dml import UpdateDMLState
100 from .dml import ValuesBase
101 from .elements import _truncated_label
102 from .elements import BinaryExpression
103 from .elements import BindParameter
104 from .elements import ClauseElement
105 from .elements import ColumnClause
106 from .elements import ColumnElement
107 from .elements import False_
108 from .elements import Label
109 from .elements import Null
110 from .elements import True_
111 from .functions import Function
112 from .schema import Column
113 from .schema import Constraint
114 from .schema import ForeignKeyConstraint
115 from .schema import Index
116 from .schema import PrimaryKeyConstraint
117 from .schema import Table
118 from .schema import UniqueConstraint
119 from .selectable import _ColumnsClauseElement
120 from .selectable import AliasedReturnsRows
121 from .selectable import CompoundSelectState
122 from .selectable import CTE
123 from .selectable import FromClause
124 from .selectable import NamedFromClause
125 from .selectable import ReturnsRows
126 from .selectable import Select
127 from .selectable import SelectState
128 from .type_api import _BindProcessorType
129 from .type_api import TypeDecorator
130 from .type_api import TypeEngine
131 from .type_api import UserDefinedType
132 from .visitors import Visitable
133 from ..engine.cursor import CursorResultMetaData
134 from ..engine.interfaces import _CoreSingleExecuteParams
135 from ..engine.interfaces import _DBAPIAnyExecuteParams
136 from ..engine.interfaces import _DBAPIMultiExecuteParams
137 from ..engine.interfaces import _DBAPISingleExecuteParams
138 from ..engine.interfaces import _ExecuteOptions
139 from ..engine.interfaces import _GenericSetInputSizesType
140 from ..engine.interfaces import _MutableCoreSingleExecuteParams
141 from ..engine.interfaces import Dialect
142 from ..engine.interfaces import SchemaTranslateMapType
143
144
# type alias for the per-FROM-clause "hint" dictionaries passed along
# during SELECT / DML compilation
_FromHintsType = Dict["FromClause", str]

# default set of lower-cased reserved words; identifiers matching one of
# these are presumably force-quoted by the IdentifierPreparer (preparer is
# defined outside this chunk — confirm there)
RESERVED_WORDS = {
    "all",
    "analyse",
    "analyze",
    "and",
    "any",
    "array",
    "as",
    "asc",
    "asymmetric",
    "authorization",
    "between",
    "binary",
    "both",
    "case",
    "cast",
    "check",
    "collate",
    "column",
    "constraint",
    "create",
    "cross",
    "current_date",
    "current_role",
    "current_time",
    "current_timestamp",
    "current_user",
    "default",
    "deferrable",
    "desc",
    "distinct",
    "do",
    "else",
    "end",
    "except",
    "false",
    "for",
    "foreign",
    "freeze",
    "from",
    "full",
    "grant",
    "group",
    "having",
    "ilike",
    "in",
    "initially",
    "inner",
    "intersect",
    "into",
    "is",
    "isnull",
    "join",
    "leading",
    "left",
    "like",
    "limit",
    "localtime",
    "localtimestamp",
    "natural",
    "new",
    "not",
    "notnull",
    "null",
    "off",
    "offset",
    "old",
    "on",
    "only",
    "or",
    "order",
    "outer",
    "overlaps",
    "placing",
    "primary",
    "references",
    "right",
    "select",
    "session_user",
    "set",
    "similar",
    "some",
    "symmetric",
    "table",
    "then",
    "to",
    "trailing",
    "true",
    "union",
    "unique",
    "user",
    "using",
    "verbose",
    "when",
    "where",
}

# identifiers made up entirely of these characters are "safe";
# the _PLUS_SPACE variant additionally permits embedded spaces
LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I)
LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I)
# characters (digits and "$") that may not begin an identifier
ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"])

# case-insensitive validators for FOREIGN KEY constraint phrases:
# ON DELETE / ON UPDATE actions and the INITIALLY keyword
FK_ON_DELETE = re.compile(
    r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I
)
FK_ON_UPDATE = re.compile(
    r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I
)
FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I)
# locates ":name" style bind parameter markers in textual SQL; a
# preceding backslash (\x5c) escapes the marker
BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
# locates backslash-escaped bind markers so the escape can be unwrapped
BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)

# template producing "%(name)s" for pyformat paramstyles; the doubled
# percent signs survive one level of %-interpolation
_pyformat_template = "%%(%(name)s)s"
# render template for each DBAPI paramstyle; "[_POSITION]" is a
# placeholder substituted with the parameter's numeric position for the
# "numeric" styles
BIND_TEMPLATES = {
    "pyformat": _pyformat_template,
    "qmark": "?",
    "format": "%%s",
    "numeric": ":[_POSITION]",
    "numeric_dollar": "$[_POSITION]",
    "named": ":%(name)s",
}
267
268
# canonical SQL text for each operator function from the .operators
# module; note the values include the surrounding/trailing spacing
# appropriate to each operator's arity
OPERATORS = {
    # binary
    operators.and_: " AND ",
    operators.or_: " OR ",
    operators.add: " + ",
    operators.mul: " * ",
    operators.sub: " - ",
    operators.mod: " % ",
    operators.neg: "-",
    operators.lt: " < ",
    operators.le: " <= ",
    operators.ne: " != ",
    operators.gt: " > ",
    operators.ge: " >= ",
    operators.eq: " = ",
    operators.is_distinct_from: " IS DISTINCT FROM ",
    operators.is_not_distinct_from: " IS NOT DISTINCT FROM ",
    operators.concat_op: " || ",
    operators.match_op: " MATCH ",
    operators.not_match_op: " NOT MATCH ",
    operators.in_op: " IN ",
    operators.not_in_op: " NOT IN ",
    operators.comma_op: ", ",
    operators.from_: " FROM ",
    operators.as_: " AS ",
    operators.is_: " IS ",
    operators.is_not: " IS NOT ",
    operators.collate: " COLLATE ",
    # unary
    operators.exists: "EXISTS ",
    operators.distinct_op: "DISTINCT ",
    operators.inv: "NOT ",
    operators.any_op: "ANY ",
    operators.all_op: "ALL ",
    # modifiers
    operators.desc_op: " DESC",
    operators.asc_op: " ASC",
    operators.nulls_first_op: " NULLS FIRST",
    operators.nulls_last_op: " NULLS LAST",
    # bitwise
    operators.bitwise_xor_op: " ^ ",
    operators.bitwise_or_op: " | ",
    operators.bitwise_and_op: " & ",
    operators.bitwise_not_op: "~",
    operators.bitwise_lshift_op: " << ",
    operators.bitwise_rshift_op: " >> ",
}

# rendered SQL names for the built-in generic function classes from the
# .functions module
FUNCTIONS: Dict[Type[Function[Any]], str] = {
    functions.coalesce: "coalesce",
    functions.current_date: "CURRENT_DATE",
    functions.current_time: "CURRENT_TIME",
    functions.current_timestamp: "CURRENT_TIMESTAMP",
    functions.current_user: "CURRENT_USER",
    functions.localtime: "LOCALTIME",
    functions.localtimestamp: "LOCALTIMESTAMP",
    functions.random: "random",
    functions.sysdate: "sysdate",
    functions.session_user: "SESSION_USER",
    functions.user: "USER",
    functions.cube: "CUBE",
    functions.rollup: "ROLLUP",
    functions.grouping_sets: "GROUPING SETS",
}
333
334
# field names accepted by the EXTRACT construct mapped to their rendered
# form; an identity mapping here (dialects presumably substitute their
# own — confirm in dialect modules)
EXTRACT_MAP = {
    "month": "month",
    "day": "day",
    "year": "year",
    "second": "second",
    "hour": "hour",
    "doy": "doy",
    "minute": "minute",
    "quarter": "quarter",
    "dow": "dow",
    "week": "week",
    "epoch": "epoch",
    "milliseconds": "milliseconds",
    "microseconds": "microseconds",
    "timezone_hour": "timezone_hour",
    "timezone_minute": "timezone_minute",
}

# keywords rendered between the SELECTs of a compound select construct
COMPOUND_KEYWORDS = {
    selectable._CompoundSelectKeyword.UNION: "UNION",
    selectable._CompoundSelectKeyword.UNION_ALL: "UNION ALL",
    selectable._CompoundSelectKeyword.EXCEPT: "EXCEPT",
    selectable._CompoundSelectKeyword.EXCEPT_ALL: "EXCEPT ALL",
    selectable._CompoundSelectKeyword.INTERSECT: "INTERSECT",
    selectable._CompoundSelectKeyword.INTERSECT_ALL: "INTERSECT ALL",
}
361
362
class ResultColumnsEntry(NamedTuple):
    """Tracks a column expression that is expected to be represented
    in the result rows for this statement.

    This normally refers to the columns clause of a SELECT statement
    but may also refer to a RETURNING clause, as well as for dialect-specific
    emulations.

    """

    # NOTE: cursor.py also reads these fields positionally via the
    # RM_RENDERED_NAME / RM_NAME / RM_OBJECTS / RM_TYPE integer constants
    # defined later in this module; keep field order in sync with those.

    keyname: str
    """string name that's expected in cursor.description"""

    name: str
    """column name, may be labeled"""

    objects: Tuple[Any, ...]
    """sequence of objects that should be able to locate this column
    in a RowMapping. This is typically string names and aliases
    as well as Column objects.

    """

    type: TypeEngine[Any]
    """Datatype to be associated with this column. This is where
    the "result processing" logic directly links the compiled statement
    to the rows that come back from the cursor.

    """
392
393
class _ResultMapAppender(Protocol):
    """Callable protocol for appending one entry to a compiler's result
    map; arguments correspond to the fields of
    :class:`.ResultColumnsEntry`.

    """

    def __call__(
        self,
        keyname: str,
        name: str,
        objects: Sequence[Any],
        type_: TypeEngine[Any],
    ) -> None: ...
402
403
# integer indexes into ResultColumnsEntry used by cursor.py.
# some profiling showed integer access faster than named tuple
RM_RENDERED_NAME: Literal[0] = 0  # ResultColumnsEntry.keyname
RM_NAME: Literal[1] = 1  # ResultColumnsEntry.name
RM_OBJECTS: Literal[2] = 2  # ResultColumnsEntry.objects
RM_TYPE: Literal[3] = 3  # ResultColumnsEntry.type
410
411
class _BaseCompilerStackEntry(TypedDict):
    """Keys always present on an entry of the compiler's statement
    stack."""

    asfrom_froms: Set[FromClause]
    correlate_froms: Set[FromClause]
    selectable: ReturnsRows


class _CompilerStackEntry(_BaseCompilerStackEntry, total=False):
    """Full structure of a compiler statement-stack entry; the keys
    declared here are optional (``total=False``)."""

    compile_state: CompileState
    need_result_map_for_nested: bool
    need_result_map_for_compound: bool
    select_0: ReturnsRows
    insert_from_select: Select[Any]
424
425
class ExpandedState(NamedTuple):
    """State used when producing "expanded" and "post compile" bound
    parameters for a statement.

    "expanded" parameters are generated at statement execution time to
    match the number of values actually passed; the most prominent
    example is the individual elements inside of an IN expression.

    "post compile" parameters are those whose SQL literal value is
    rendered into the SQL statement at execution time rather than being
    passed to the driver as separate parameters.

    To create an :class:`.ExpandedState` instance, use the
    :meth:`.SQLCompiler.construct_expanded_state` method on any
    :class:`.SQLCompiler` instance.

    """

    statement: str
    """String SQL statement with parameters fully expanded"""

    parameters: _CoreSingleExecuteParams
    """Parameter dictionary with parameters fully expanded.

    For a statement that uses named parameters, this dictionary maps
    exactly to the names in the statement.  For a positional paramstyle,
    use :attr:`.ExpandedState.positional_parameters` to obtain the values
    as an ordered tuple.

    """

    processors: Mapping[str, _BindProcessorType[Any]]
    """mapping of bound value processors"""

    positiontup: Optional[Sequence[str]]
    """Sequence of string names indicating the order of positional
    parameters"""

    parameter_expansion: Mapping[str, List[str]]
    """Mapping representing the intermediary link from original parameter
    name to list of "expanded" parameter names, for those parameters that
    were expanded."""

    @property
    def positional_parameters(self) -> Tuple[Any, ...]:
        """Tuple of parameter values in positional order, for statements
        that were compiled using a positional paramstyle.

        """
        positions = self.positiontup
        if positions is None:
            raise exc.InvalidRequestError(
                "statement does not use a positional paramstyle"
            )
        source = self.parameters
        return tuple([source[name] for name in positions])

    @property
    def additional_parameters(self) -> _CoreSingleExecuteParams:
        """synonym for :attr:`.ExpandedState.parameters`."""
        return self.parameters
485
486
class _InsertManyValues(NamedTuple):
    """represents state to use for executing an "insertmanyvalues" statement.

    The primary consumers of this object are the
    :meth:`.SQLCompiler._deliver_insertmanyvalues_batches` and
    :meth:`.DefaultDialect._deliver_insertmanyvalues_batches` methods.

    .. versionadded:: 2.0

    """

    is_default_expr: bool
    """if True, the statement is of the form
    ``INSERT INTO TABLE DEFAULT VALUES``, and can't be rewritten as a "batch"

    """

    single_values_expr: str
    """The rendered "values" clause of the INSERT statement.

    This is typically the parenthesized section e.g. "(?, ?, ?)" or similar.
    The insertmanyvalues logic uses this string as a search and replace
    target.

    """

    insert_crud_params: List[crud._CrudParamElementStr]
    """List of Column / bind names etc. used while rewriting the statement"""

    num_positional_params_counted: int
    """the number of bound parameters in a single-row statement.

    This count may be larger or smaller than the actual number of columns
    targeted in the INSERT, as it accommodates for SQL expressions
    in the values list that may have zero or more parameters embedded
    within them.

    This count is part of what's used to organize rewritten parameter lists
    when batching.

    """

    sort_by_parameter_order: bool = False
    """if deterministic ordering of returned rows via the
    ``sort_by_parameter_order`` parameter was requested on the insert.

    All of the attributes following this will only be used if this is True.

    """

    includes_upsert_behaviors: bool = False
    """if True, we have to accommodate for upsert behaviors.

    This will in some cases downgrade "insertmanyvalues" that requests
    deterministic ordering.

    """

    sentinel_columns: Optional[Sequence[Column[Any]]] = None
    """List of sentinel columns that were located.

    This list is only here if the INSERT asked for
    sort_by_parameter_order=True,
    and dialect-appropriate sentinel columns were located.

    .. versionadded:: 2.0.10

    """

    num_sentinel_columns: int = 0
    """how many sentinel columns are in the above list, if any.

    This is the same as
    ``len(sentinel_columns) if sentinel_columns is not None else 0``

    """

    sentinel_param_keys: Optional[Sequence[str]] = None
    """parameter str keys in each param dictionary / tuple
    that would link to the client side "sentinel" values for that row, which
    we can use to match up parameter sets to result rows.

    This is only present if sentinel_columns is present and the INSERT
    statement actually refers to client side values for these sentinel
    columns.

    .. versionadded:: 2.0.10

    .. versionchanged:: 2.0.29 - the sequence is now string dictionary keys
       only, used against the "compiled parameters" collection before
       the parameters were converted by bound parameter processors

    """

    implicit_sentinel: bool = False
    """if True, we have exactly one sentinel column and it uses a server side
    value, currently has to generate an incrementing integer value.

    The dialect in question would have asserted that it supports receiving
    these values back and sorting on that value as a means of guaranteeing
    correlation with the incoming parameter list.

    .. versionadded:: 2.0.10

    """

    embed_values_counter: bool = False
    """Whether to embed an incrementing integer counter in each parameter
    set within the VALUES clause as parameters are batched over.

    This is only used for a specific INSERT..SELECT..VALUES..RETURNING syntax
    where a subquery is used to produce value tuples. Current support
    includes PostgreSQL, Microsoft SQL Server.

    .. versionadded:: 2.0.10

    """
604
605
class _InsertManyValuesBatch(NamedTuple):
    """represents an individual batch SQL statement for insertmanyvalues.

    This is passed through the
    :meth:`.SQLCompiler._deliver_insertmanyvalues_batches` and
    :meth:`.DefaultDialect._deliver_insertmanyvalues_batches` methods out
    to the :class:`.Connection` within the
    :meth:`.Connection._exec_insertmany_context` method.

    .. versionadded:: 2.0.10

    """

    # the SQL statement string to invoke for this batch
    replaced_statement: str
    # parameters matching replaced_statement
    replaced_parameters: _DBAPIAnyExecuteParams
    # per-batch setinputsizes data, if the dialect uses it
    processed_setinputsizes: Optional[_GenericSetInputSizesType]
    # the individual parameter sets that make up this batch
    batch: Sequence[_DBAPISingleExecuteParams]
    # client-side sentinel value tuples for the rows in this batch
    sentinel_values: Sequence[Tuple[Any, ...]]
    # number of parameter sets in this batch
    current_batch_size: int
    # ordinal of this batch (confirm base index at call sites)
    batchnum: int
    # total number of batches the statement was split into
    total_batches: int
    # whether result rows are sorted to correspond to parameter order
    rows_sorted: bool
    # whether the operation was downgraded from full insertmanyvalues
    # behavior
    is_downgraded: bool
629
630
class InsertmanyvaluesSentinelOpts(FastIntFlag):
    """bitflag enum indicating styles of PK defaults
    which can work as implicit sentinel columns

    """

    NOT_SUPPORTED = 1
    AUTOINCREMENT = 2
    IDENTITY = 4
    SEQUENCE = 8

    # union of the three server-generated default styles above
    ANY_AUTOINCREMENT = AUTOINCREMENT | IDENTITY | SEQUENCE
    # mask covering both the supported and not-supported flags
    _SUPPORTED_OR_NOT = NOT_SUPPORTED | ANY_AUTOINCREMENT

    USE_INSERT_FROM_SELECT = 16
    RENDER_SELECT_COL_CASTS = 64
647
648
# lifecycle states for a Compiled object; stored as Compiled.state
class CompilerState(IntEnum):
    COMPILING = 0
    """statement is present, compilation phase in progress"""

    STRING_APPLIED = 1
    """statement is present, string form of the statement has been applied.

    Additional processors by subclasses may still be pending.

    """

    NO_STATEMENT = 2
    """compiler does not have a statement to compile, is used
    for method access"""
663
664
class Linting(IntEnum):
    """represent preferences for the 'SQL linting' feature.

    this feature currently includes support for flagging cartesian products
    in SQL statements.

    """

    NO_LINTING = 0
    "Disable all linting."

    COLLECT_CARTESIAN_PRODUCTS = 1
    """Collect data on FROMs and cartesian products and gather into
    'self.from_linter'"""

    WARN_LINTING = 2
    "Emit warnings for linters that find problems"

    FROM_LINTING = COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING
    """Warn for cartesian products; combines COLLECT_CARTESIAN_PRODUCTS
    and WARN_LINTING"""


# re-export the four Linting members as module-level names for
# convenient use as compiler arguments
NO_LINTING, COLLECT_CARTESIAN_PRODUCTS, WARN_LINTING, FROM_LINTING = tuple(
    Linting
)
691
692
class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])):
    """represents current state for the "cartesian product" detection
    feature.

    ``froms`` is a mapping of FROM elements to a display name for each;
    ``edges`` is a set of 2-tuples, each linking two FROM elements that
    are joined in the statement.

    """

    def lint(self, start=None):
        """Traverse the join graph starting from ``start`` (or an
        arbitrary element) and return a tuple
        ``(unreachable_froms, start_element)``, or ``(None, None)`` if
        every FROM element is connected to the start.

        """
        froms = self.froms
        if not froms:
            # nothing to lint
            return None, None

        edges = set(self.edges)
        the_rest = set(froms)

        if start is not None:
            start_with = start
            the_rest.remove(start_with)
        else:
            start_with = the_rest.pop()

        stack = collections.deque([start_with])

        # breadth-ish traversal: consume edges as their nodes are visited
        while stack and the_rest:
            node = stack.popleft()
            the_rest.discard(node)

            # comparison of nodes in edges here is based on hash equality, as
            # there are "annotated" elements that match the non-annotated ones.
            # to remove the need for in-python hash() calls, use native
            # containment routines (e.g. "node in edge", "edge.index(node)")
            to_remove = {edge for edge in edges if node in edge}

            # appendleft the node in each edge that is not
            # the one that matched.
            stack.extendleft(edge[not edge.index(node)] for edge in to_remove)
            edges.difference_update(to_remove)

        # FROMS left over? boom
        if the_rest:
            return the_rest, start_with
        else:
            return None, None

    def warn(self, stmt_type="SELECT"):
        """Run :meth:`.lint` and emit a warning if disconnected FROM
        elements (a cartesian product) were detected.

        :param stmt_type: statement keyword used in the warning text.

        """
        the_rest, start_with = self.lint()

        # FROMS left over? boom
        # (fix: collapsed the previous redundant double truth-test of
        # the same set into a single check)
        if the_rest:
            template = (
                "{stmt_type} statement has a cartesian product between "
                "FROM element(s) {froms} and "
                'FROM element "{start}". Apply join condition(s) '
                "between each element to resolve."
            )
            froms_str = ", ".join(
                f'"{self.froms[from_]}"' for from_ in the_rest
            )
            message = template.format(
                stmt_type=stmt_type,
                froms=froms_str,
                start=self.froms[start_with],
            )

            util.warn(message)
757
758
class Compiled:
    """Represent a compiled SQL or DDL expression.

    The ``__str__`` method of the ``Compiled`` object should produce
    the actual text of the statement. ``Compiled`` objects are
    specific to their underlying database dialect, and also may
    or may not be specific to the columns referenced within a
    particular set of bind parameters. In no case should the
    ``Compiled`` object be dependent on the actual values of those
    bind parameters, even though it may reference those values as
    defaults.
    """

    statement: Optional[ClauseElement] = None
    "The statement to compile."
    string: str = ""
    "The string representation of the ``statement``"

    state: CompilerState
    """description of the compiler's state"""

    # discriminators overridden by subclasses; e.g. SQLCompiler
    # sets is_sql = True
    is_sql = False
    is_ddl = False

    # cached result metadata associated with this compiled statement
    _cached_metadata: Optional[CursorResultMetaData] = None

    # result-column entries collected during compilation, if any
    _result_columns: Optional[List[ResultColumnsEntry]] = None

    schema_translate_map: Optional[SchemaTranslateMapType] = None

    execution_options: _ExecuteOptions = util.EMPTY_DICT
    """
    Execution options propagated from the statement.   In some cases,
    sub-elements of the statement can modify these.
    """

    preparer: IdentifierPreparer

    _annotations: _AnnotationDict = util.EMPTY_DICT

    compile_state: Optional[CompileState] = None
    """Optional :class:`.CompileState` object that maintains additional
    state used by the compiler.

    Major executable objects such as :class:`_expression.Insert`,
    :class:`_expression.Update`, :class:`_expression.Delete`,
    :class:`_expression.Select` will generate this
    state when compiled in order to calculate additional information about the
    object. For the top level object that is to be executed, the state can be
    stored here where it can also have applicability towards result set
    processing.

    .. versionadded:: 1.4

    """

    dml_compile_state: Optional[CompileState] = None
    """Optional :class:`.CompileState` assigned at the same point that
    .isinsert, .isupdate, or .isdelete is assigned.

    This will normally be the same object as .compile_state, with the
    exception of cases like the :class:`.ORMFromStatementCompileState`
    object.

    .. versionadded:: 1.4.40

    """

    cache_key: Optional[CacheKey] = None
    """The :class:`.CacheKey` that was generated ahead of creating this
    :class:`.Compiled` object.

    This is used for routines that need access to the original
    :class:`.CacheKey` instance generated when the :class:`.Compiled`
    instance was first cached, typically in order to reconcile
    the original list of :class:`.BindParameter` objects with a
    per-statement list that's generated on each call.

    """

    _gen_time: float
    """Generation time of this :class:`.Compiled`, used for reporting
    cache stats."""

    def __init__(
        self,
        dialect: Dialect,
        statement: Optional[ClauseElement],
        schema_translate_map: Optional[SchemaTranslateMapType] = None,
        render_schema_translate: bool = False,
        compile_kwargs: Mapping[str, Any] = util.immutabledict(),
    ):
        """Construct a new :class:`.Compiled` object.

        :param dialect: :class:`.Dialect` to compile against.

        :param statement: :class:`_expression.ClauseElement` to be compiled.

        :param schema_translate_map: dictionary of schema names to be
         translated when forming the resultant SQL

         .. seealso::

            :ref:`schema_translating`

        :param render_schema_translate: if True, the final SQL string is
         post-processed so that schema names from the translate map are
         rendered into it.

        :param compile_kwargs: additional kwargs that will be
         passed to the initial call to :meth:`.Compiled.process`.


        """
        self.dialect = dialect
        self.preparer = self.dialect.identifier_preparer
        if schema_translate_map:
            self.schema_translate_map = schema_translate_map
            # swap in a preparer that applies the translate map when
            # rendering schema names
            self.preparer = self.preparer._with_schema_translate(
                schema_translate_map
            )

        if statement is not None:
            # main compilation path: dispatch the statement through
            # process() to produce the SQL string
            self.state = CompilerState.COMPILING
            self.statement = statement
            self.can_execute = statement.supports_execution
            self._annotations = statement._annotations
            if self.can_execute:
                if TYPE_CHECKING:
                    assert isinstance(statement, Executable)
                self.execution_options = statement._execution_options
            self.string = self.process(self.statement, **compile_kwargs)

            if render_schema_translate:
                # apply schema translations directly into the string
                assert schema_translate_map is not None
                self.string = self.preparer._render_schema_translates(
                    self.string, schema_translate_map
                )

            self.state = CompilerState.STRING_APPLIED
        else:
            # no statement; instance exists for method access only
            self.state = CompilerState.NO_STATEMENT

        # timestamp used for cache-stats reporting (see _gen_time above)
        self._gen_time = perf_counter()

    def __init_subclass__(cls) -> None:
        # gives subclasses a hook to set up class-level compiler state
        cls._init_compiler_cls()
        return super().__init_subclass__()

    @classmethod
    def _init_compiler_cls(cls):
        """Hook invoked for each subclass at class-creation time;
        no-op by default."""
        pass

    def _execute_on_connection(
        self, connection, distilled_params, execution_options
    ):
        # invoked by Connection to execute this compiled statement
        if self.can_execute:
            return connection._execute_compiled(
                self, distilled_params, execution_options
            )
        else:
            raise exc.ObjectNotExecutableError(self.statement)

    def visit_unsupported_compilation(self, element, err, **kw):
        """Fallback visit method raising for elements this compiler
        can't process; chains the original dispatch error."""
        raise exc.UnsupportedCompilationError(self, type(element)) from err

    @property
    def sql_compiler(self) -> SQLCompiler:
        """Return a Compiled that is capable of processing SQL expressions.

        If this compiler is one, it would likely just return 'self'.

        """

        raise NotImplementedError()

    def process(self, obj: Visitable, **kwargs: Any) -> str:
        """Compile *obj* via its compiler-dispatch hook, returning the
        resulting string fragment."""
        return obj._compiler_dispatch(self, **kwargs)

    def __str__(self) -> str:
        """Return the string text of the generated SQL or DDL."""

        # empty string when no statement string has been applied yet
        if self.state is CompilerState.STRING_APPLIED:
            return self.string
        else:
            return ""

    def construct_params(
        self,
        params: Optional[_CoreSingleExecuteParams] = None,
        extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
        escape_names: bool = True,
    ) -> Optional[_MutableCoreSingleExecuteParams]:
        """Return the bind params for this compiled object.

        :param params: a dict of string/object pairs whose values will
         override bind values compiled in to the
         statement.
        """

        raise NotImplementedError()

    @property
    def params(self):
        """Return the bind params for this compiled object."""
        return self.construct_params()
961
962
class TypeCompiler(util.EnsureKWArg):
    """Renders DDL type-specification strings for
    :class:`.TypeEngine` objects."""

    # all visit_ methods must accept **kw (enforced by EnsureKWArg)
    ensure_kwarg = r"visit_\w+"

    def __init__(self, dialect: Dialect):
        self.dialect = dialect

    def process(self, type_: TypeEngine[Any], **kw: Any) -> str:
        """Render the type specification for ``type_``, first swapping in
        a dialect-specific variant of the type if one was established."""
        variants = type_._variant_mapping
        if variants and self.dialect.name in variants:
            type_ = variants[self.dialect.name]
        return type_._compiler_dispatch(self, **kw)

    def visit_unsupported_compilation(
        self, element: Any, err: Exception, **kw: Any
    ) -> NoReturn:
        """Fallback raising for types this compiler can't render."""
        raise exc.UnsupportedCompilationError(self, element) from err
983
984
985# this was a Visitable, but to allow accurate detection of
986# column elements this is actually a column element
class _CompileLabel(
    roles.BinaryElementRole[Any], elements.CompilerColumnElement
):
    """Minimal stand-in for an ``expression.Label``, used only during
    compilation; acts as a column element."""

    __visit_name__ = "label"
    __slots__ = "element", "name", "_alt_names"

    def __init__(self, col, name, alt_names=()):
        self.name = name
        self.element = col
        # the wrapped column always participates as an alternate name
        self._alt_names = (col,) + alt_names

    def self_group(self, **kw):
        # already atomic; never needs parenthesization
        return self

    @property
    def type(self):
        # delegate datatype to the wrapped element
        return self.element.type

    @property
    def proxy_set(self):
        # delegate proxy set membership to the wrapped element
        return self.element.proxy_set
1010
1011
class ilike_case_insensitive(
    roles.BinaryElementRole[Any], elements.CompilerColumnElement
):
    """Wrapping element marking the case-insensitive portion of an
    ILIKE construct.

    Rendering usually produces the ``lower()`` function; on PostgreSQL
    it passes through silently, as "ILIKE" itself is assumed to be
    in use there.

    .. versionadded:: 2.0

    """

    __visit_name__ = "ilike_case_insensitive_operand"
    __slots__ = "element", "comparator"

    def __init__(self, element):
        self.comparator = element.comparator
        self.element = element

    def self_group(self, **kw):
        # already atomic; never needs parenthesization
        return self

    def _with_binary_element_type(self, type_):
        # re-wrap the type-adapted form of the underlying element
        adapted = self.element._with_binary_element_type(type_)
        return ilike_case_insensitive(adapted)

    @property
    def type(self):
        # delegate datatype to the wrapped element
        return self.element.type

    @property
    def proxy_set(self):
        # delegate proxy set membership to the wrapped element
        return self.element.proxy_set
1048
1049
1050class SQLCompiler(Compiled):
1051 """Default implementation of :class:`.Compiled`.
1052
1053 Compiles :class:`_expression.ClauseElement` objects into SQL strings.
1054
1055 """
1056
1057 extract_map = EXTRACT_MAP
1058
1059 bindname_escape_characters: ClassVar[Mapping[str, str]] = (
1060 util.immutabledict(
1061 {
1062 "%": "P",
1063 "(": "A",
1064 ")": "Z",
1065 ":": "C",
1066 ".": "_",
1067 "[": "_",
1068 "]": "_",
1069 " ": "_",
1070 }
1071 )
1072 )
1073 """A mapping (e.g. dict or similar) containing a lookup of
1074 characters keyed to replacement characters which will be applied to all
1075 'bind names' used in SQL statements as a form of 'escaping'; the given
1076 characters are replaced entirely with the 'replacement' character when
1077 rendered in the SQL statement, and a similar translation is performed
1078 on the incoming names used in parameter dictionaries passed to methods
1079 like :meth:`_engine.Connection.execute`.
1080
1081 This allows bound parameter names used in :func:`_sql.bindparam` and
1082 other constructs to have any arbitrary characters present without any
1083 concern for characters that aren't allowed at all on the target database.
1084
1085 Third party dialects can establish their own dictionary here to replace the
1086 default mapping, which will ensure that the particular characters in the
1087 mapping will never appear in a bound parameter name.
1088
1089 The dictionary is evaluated at **class creation time**, so cannot be
1090 modified at runtime; it must be present on the class when the class
1091 is first declared.
1092
1093 Note that for dialects that have additional bound parameter rules such
1094 as additional restrictions on leading characters, the
1095 :meth:`_sql.SQLCompiler.bindparam_string` method may need to be augmented.
1096 See the cx_Oracle compiler for an example of this.
1097
1098 .. versionadded:: 2.0.0rc1
1099
1100 """
1101
1102 _bind_translate_re: ClassVar[Pattern[str]]
1103 _bind_translate_chars: ClassVar[Mapping[str, str]]
1104
1105 is_sql = True
1106
1107 compound_keywords = COMPOUND_KEYWORDS
1108
1109 isdelete: bool = False
1110 isinsert: bool = False
1111 isupdate: bool = False
1112 """class-level defaults which can be set at the instance
1113 level to define if this Compiled instance represents
1114 INSERT/UPDATE/DELETE
1115 """
1116
1117 postfetch: Optional[List[Column[Any]]]
1118 """list of columns that can be post-fetched after INSERT or UPDATE to
1119 receive server-updated values"""
1120
1121 insert_prefetch: Sequence[Column[Any]] = ()
1122 """list of columns for which default values should be evaluated before
1123 an INSERT takes place"""
1124
1125 update_prefetch: Sequence[Column[Any]] = ()
1126 """list of columns for which onupdate default values should be evaluated
1127 before an UPDATE takes place"""
1128
1129 implicit_returning: Optional[Sequence[ColumnElement[Any]]] = None
1130 """list of "implicit" returning columns for a toplevel INSERT or UPDATE
1131 statement, used to receive newly generated values of columns.
1132
1133 .. versionadded:: 2.0 ``implicit_returning`` replaces the previous
1134 ``returning`` collection, which was not a generalized RETURNING
1135 collection and instead was in fact specific to the "implicit returning"
1136 feature.
1137
1138 """
1139
1140 isplaintext: bool = False
1141
1142 binds: Dict[str, BindParameter[Any]]
1143 """a dictionary of bind parameter keys to BindParameter instances."""
1144
1145 bind_names: Dict[BindParameter[Any], str]
1146 """a dictionary of BindParameter instances to "compiled" names
1147 that are actually present in the generated SQL"""
1148
1149 stack: List[_CompilerStackEntry]
1150 """major statements such as SELECT, INSERT, UPDATE, DELETE are
1151 tracked in this stack using an entry format."""
1152
1153 returning_precedes_values: bool = False
1154 """set to True classwide to generate RETURNING
1155 clauses before the VALUES or WHERE clause (i.e. MSSQL)
1156 """
1157
1158 render_table_with_column_in_update_from: bool = False
1159 """set to True classwide to indicate the SET clause
1160 in a multi-table UPDATE statement should qualify
1161 columns with the table name (i.e. MySQL only)
1162 """
1163
1164 ansi_bind_rules: bool = False
1165 """SQL 92 doesn't allow bind parameters to be used
1166 in the columns clause of a SELECT, nor does it allow
1167 ambiguous expressions like "? = ?". A compiler
1168 subclass can set this flag to False if the target
1169 driver/DB enforces this
1170 """
1171
1172 bindtemplate: str
1173 """template to render bound parameters based on paramstyle."""
1174
1175 compilation_bindtemplate: str
1176 """template used by compiler to render parameters before positional
1177 paramstyle application"""
1178
1179 _numeric_binds_identifier_char: str
1180 """Character that's used to as the identifier of a numerical bind param.
1181 For example if this char is set to ``$``, numerical binds will be rendered
1182 in the form ``$1, $2, $3``.
1183 """
1184
1185 _result_columns: List[ResultColumnsEntry]
1186 """relates label names in the final SQL to a tuple of local
1187 column/label name, ColumnElement object (if any) and
1188 TypeEngine. CursorResult uses this for type processing and
1189 column targeting"""
1190
1191 _textual_ordered_columns: bool = False
1192 """tell the result object that the column names as rendered are important,
1193 but they are also "ordered" vs. what is in the compiled object here.
1194
1195 As of 1.4.42 this condition is only present when the statement is a
1196 TextualSelect, e.g. text("....").columns(...), where it is required
1197 that the columns are considered positionally and not by name.
1198
1199 """
1200
1201 _ad_hoc_textual: bool = False
1202 """tell the result that we encountered text() or '*' constructs in the
1203 middle of the result columns, but we also have compiled columns, so
1204 if the number of columns in cursor.description does not match how many
1205 expressions we have, that means we can't rely on positional at all and
1206 should match on name.
1207
1208 """
1209
1210 _ordered_columns: bool = True
1211 """
1212 if False, means we can't be sure the list of entries
1213 in _result_columns is actually the rendered order. Usually
1214 True unless using an unordered TextualSelect.
1215 """
1216
1217 _loose_column_name_matching: bool = False
1218 """tell the result object that the SQL statement is textual, wants to match
1219 up to Column objects, and may be using the ._tq_label in the SELECT rather
1220 than the base name.
1221
1222 """
1223
1224 _numeric_binds: bool = False
1225 """
1226 True if paramstyle is "numeric". This paramstyle is trickier than
1227 all the others.
1228
1229 """
1230
1231 _render_postcompile: bool = False
1232 """
1233 whether to render out POSTCOMPILE params during the compile phase.
1234
1235 This attribute is used only for end-user invocation of stmt.compile();
1236 it's never used for actual statement execution, where instead the
1237 dialect internals access and render the internal postcompile structure
1238 directly.
1239
1240 """
1241
1242 _post_compile_expanded_state: Optional[ExpandedState] = None
1243 """When render_postcompile is used, the ``ExpandedState`` used to create
1244 the "expanded" SQL is assigned here, and then used by the ``.params``
1245 accessor and ``.construct_params()`` methods for their return values.
1246
1247 .. versionadded:: 2.0.0rc1
1248
1249 """
1250
1251 _pre_expanded_string: Optional[str] = None
1252 """Stores the original string SQL before 'post_compile' is applied,
    for cases where 'post_compile' was used.
1254
1255 """
1256
1257 _pre_expanded_positiontup: Optional[List[str]] = None
1258
1259 _insertmanyvalues: Optional[_InsertManyValues] = None
1260
1261 _insert_crud_params: Optional[crud._CrudParamSequence] = None
1262
1263 literal_execute_params: FrozenSet[BindParameter[Any]] = frozenset()
1264 """bindparameter objects that are rendered as literal values at statement
1265 execution time.
1266
1267 """
1268
1269 post_compile_params: FrozenSet[BindParameter[Any]] = frozenset()
1270 """bindparameter objects that are rendered as bound parameter placeholders
1271 at statement execution time.
1272
1273 """
1274
1275 escaped_bind_names: util.immutabledict[str, str] = util.EMPTY_DICT
1276 """Late escaping of bound parameter names that has to be converted
1277 to the original name when looking in the parameter dictionary.
1278
1279 """
1280
1281 has_out_parameters = False
1282 """if True, there are bindparam() objects that have the isoutparam
1283 flag set."""
1284
1285 postfetch_lastrowid = False
1286 """if True, and this in insert, use cursor.lastrowid to populate
1287 result.inserted_primary_key. """
1288
1289 _cache_key_bind_match: Optional[
1290 Tuple[
1291 Dict[
1292 BindParameter[Any],
1293 List[BindParameter[Any]],
1294 ],
1295 Dict[
1296 str,
1297 BindParameter[Any],
1298 ],
1299 ]
1300 ] = None
1301 """a mapping that will relate the BindParameter object we compile
1302 to those that are part of the extracted collection of parameters
1303 in the cache key, if we were given a cache key.
1304
1305 """
1306
1307 positiontup: Optional[List[str]] = None
1308 """for a compiled construct that uses a positional paramstyle, will be
1309 a sequence of strings, indicating the names of bound parameters in order.
1310
1311 This is used in order to render bound parameters in their correct order,
1312 and is combined with the :attr:`_sql.Compiled.params` dictionary to
1313 render parameters.
1314
1315 This sequence always contains the unescaped name of the parameters.
1316
1317 .. seealso::
1318
1319 :ref:`faq_sql_expression_string` - includes a usage example for
1320 debugging use cases.
1321
1322 """
1323 _values_bindparam: Optional[List[str]] = None
1324
1325 _visited_bindparam: Optional[List[str]] = None
1326
1327 inline: bool = False
1328
1329 ctes: Optional[MutableMapping[CTE, str]]
1330
1331 # Detect same CTE references - Dict[(level, name), cte]
1332 # Level is required for supporting nesting
1333 ctes_by_level_name: Dict[Tuple[int, str], CTE]
1334
1335 # To retrieve key/level in ctes_by_level_name -
1336 # Dict[cte_reference, (level, cte_name, cte_opts)]
1337 level_name_by_cte: Dict[CTE, Tuple[int, str, selectable._CTEOpts]]
1338
1339 ctes_recursive: bool
1340
    # matches the rendered ``__[POSTCOMPILE_<name>]`` token used for
    # "post compile" parameters, with an optional ``~~...~~`` payload
    _post_compile_pattern = re.compile(r"__\[POSTCOMPILE_(\S+?)(~~.+?~~)?\]")
    # matches pyformat-style bound parameter tokens, e.g. ``%(name)s``
    _pyformat_pattern = re.compile(r"%\(([^)]+?)\)s")
    # either of the above; used when rewriting the intermediate string
    # for positional paramstyles
    _positional_pattern = re.compile(
        f"{_pyformat_pattern.pattern}|{_post_compile_pattern.pattern}"
    )
1346
    @classmethod
    def _init_compiler_cls(cls):
        # class-level setup hook: build the bind-name escape regex
        # from ``bindname_escape_characters``
        cls._init_bind_translate()
1350
1351 @classmethod
1352 def _init_bind_translate(cls):
1353 reg = re.escape("".join(cls.bindname_escape_characters))
1354 cls._bind_translate_re = re.compile(f"[{reg}]")
1355 cls._bind_translate_chars = cls.bindname_escape_characters
1356
    def __init__(
        self,
        dialect: Dialect,
        statement: Optional[ClauseElement],
        cache_key: Optional[CacheKey] = None,
        column_keys: Optional[Sequence[str]] = None,
        for_executemany: bool = False,
        linting: Linting = NO_LINTING,
        _supporting_against: Optional[SQLCompiler] = None,
        **kwargs: Any,
    ):
        """Construct a new :class:`.SQLCompiler` object.

        :param dialect: :class:`.Dialect` to be used

        :param statement: :class:`_expression.ClauseElement` to be compiled

        :param cache_key: the :class:`.CacheKey` generated for the statement,
         if any; used by :meth:`.construct_params` to relate incoming
         ``extracted_parameters`` back to this compiled object's bound
         parameters

        :param column_keys: a list of column names to be compiled into an
         INSERT or UPDATE statement.

        :param for_executemany: whether INSERT / UPDATE statements should
         expect that they are to be invoked in an "executemany" style,
         which may impact how the statement will be expected to return the
         values of defaults and autoincrement / sequences and similar.
         Depending on the backend and driver in use, support for retrieving
         these values may be disabled which means SQL expressions may
         be rendered inline, RETURNING may not be rendered, etc.

        :param linting: SQL linting flags, assigned to ``self.linting``

        :param kwargs: additional keyword arguments to be consumed by the
         superclass.

        """
        self.column_keys = column_keys

        self.cache_key = cache_key

        if cache_key:
            # cksm: bindparam .key -> cache key bindparam;
            # ckbm: cache key bindparam -> list of compiled bindparams.
            # consumed by construct_params() when extracted_parameters
            # from another cache key need to line up with ours
            cksm = {b.key: b for b in cache_key[1]}
            ckbm = {b: [b] for b in cache_key[1]}
            self._cache_key_bind_match = (ckbm, cksm)

        # compile INSERT/UPDATE defaults/sequences to expect executemany
        # style execution, which may mean no pre-execute of defaults,
        # or no RETURNING
        self.for_executemany = for_executemany

        self.linting = linting

        # a dictionary of bind parameter keys to BindParameter
        # instances.
        self.binds = {}

        # a dictionary of BindParameter instances to "compiled" names
        # that are actually present in the generated SQL
        self.bind_names = util.column_dict()

        # stack which keeps track of nested SELECT statements
        self.stack = []

        self._result_columns = []

        # true if the paramstyle is positional
        self.positional = dialect.positional
        if self.positional:
            self._numeric_binds = nb = dialect.paramstyle.startswith("numeric")
            if nb:
                self._numeric_binds_identifier_char = (
                    "$" if dialect.paramstyle == "numeric_dollar" else ":"
                )

            # positional statements are first rendered in pyformat style,
            # then rewritten by _process_positional() / _process_numeric()
            self.compilation_bindtemplate = _pyformat_template
        else:
            self.compilation_bindtemplate = BIND_TEMPLATES[dialect.paramstyle]

        self.ctes = None

        self.label_length = (
            dialect.label_length or dialect.max_identifier_length
        )

        # a map which tracks "anonymous" identifiers that are created on
        # the fly here
        self.anon_map = prefix_anon_map()

        # a map which tracks "truncated" names based on
        # dialect.label_length or dialect.max_identifier_length
        self.truncated_names: Dict[Tuple[str, str], str] = {}
        self._truncated_counters: Dict[str, int] = {}

        # the superclass __init__ performs the actual compilation; it
        # establishes self.state and the isinsert/isupdate/isdelete flags
        # consulted below
        Compiled.__init__(self, dialect, statement, **kwargs)

        if self.isinsert or self.isupdate or self.isdelete:
            if TYPE_CHECKING:
                assert isinstance(statement, UpdateBase)

            if self.isinsert or self.isupdate:
                if TYPE_CHECKING:
                    assert isinstance(statement, ValuesBase)
                if statement._inline:
                    self.inline = True
                elif self.for_executemany and (
                    not self.isinsert
                    or (
                        self.dialect.insert_executemany_returning
                        and statement._return_defaults
                    )
                ):
                    self.inline = True

        self.bindtemplate = BIND_TEMPLATES[dialect.paramstyle]

        if _supporting_against:
            # copy state over from another compiler, excluding
            # paramstyle / dialect-specific members
            self.__dict__.update(
                {
                    k: v
                    for k, v in _supporting_against.__dict__.items()
                    if k
                    not in {
                        "state",
                        "dialect",
                        "preparer",
                        "positional",
                        "_numeric_binds",
                        "compilation_bindtemplate",
                        "bindtemplate",
                    }
                }
            )

        if self.state is CompilerState.STRING_APPLIED:
            if self.positional:
                # rewrite the pyformat-style intermediate string into the
                # dialect's actual positional placeholder format
                if self._numeric_binds:
                    self._process_numeric()
                else:
                    self._process_positional()

            if self._render_postcompile:
                parameters = self.construct_params(
                    escape_names=False,
                    _no_postcompile=True,
                )

                self._process_parameters_for_postcompile(
                    parameters, _populate_self=True
                )
1502
1503 @property
1504 def insert_single_values_expr(self) -> Optional[str]:
1505 """When an INSERT is compiled with a single set of parameters inside
1506 a VALUES expression, the string is assigned here, where it can be
1507 used for insert batching schemes to rewrite the VALUES expression.
1508
1509 .. versionadded:: 1.3.8
1510
1511 .. versionchanged:: 2.0 This collection is no longer used by
1512 SQLAlchemy's built-in dialects, in favor of the currently
1513 internal ``_insertmanyvalues`` collection that is used only by
1514 :class:`.SQLCompiler`.
1515
1516 """
1517 if self._insertmanyvalues is None:
1518 return None
1519 else:
1520 return self._insertmanyvalues.single_values_expr
1521
1522 @util.ro_memoized_property
1523 def effective_returning(self) -> Optional[Sequence[ColumnElement[Any]]]:
1524 """The effective "returning" columns for INSERT, UPDATE or DELETE.
1525
1526 This is either the so-called "implicit returning" columns which are
1527 calculated by the compiler on the fly, or those present based on what's
1528 present in ``self.statement._returning`` (expanded into individual
1529 columns using the ``._all_selected_columns`` attribute) i.e. those set
1530 explicitly using the :meth:`.UpdateBase.returning` method.
1531
1532 .. versionadded:: 2.0
1533
1534 """
1535 if self.implicit_returning:
1536 return self.implicit_returning
1537 elif self.statement is not None and is_dml(self.statement):
1538 return [
1539 c
1540 for c in self.statement._all_selected_columns
1541 if is_column_element(c)
1542 ]
1543
1544 else:
1545 return None
1546
    @property
    def returning(self):
        """backwards compatibility; returns the
        effective_returning collection.

        .. seealso::

            :attr:`.SQLCompiler.effective_returning`

        """
        return self.effective_returning
1554
    @property
    def current_executable(self):
        """Return the current 'executable' that is being compiled.

        This is currently the :class:`_sql.Select`, :class:`_sql.Insert`,
        :class:`_sql.Update`, :class:`_sql.Delete`,
        :class:`_sql.CompoundSelect` object that is being compiled.
        Specifically it's assigned to the ``self.stack`` list of elements.

        When a statement like the above is being compiled, it normally
        is also assigned to the ``.statement`` attribute of the
        :class:`_sql.Compiler` object. However, all SQL constructs are
        ultimately nestable, and this attribute should never be consulted
        by a ``visit_`` method, as it is not guaranteed to be assigned
        nor guaranteed to correspond to the current statement being compiled.

        .. versionadded:: 1.3.21

        For compatibility with previous versions, use the following
        recipe::

            statement = getattr(self, "current_executable", False)
            if statement is False:
                statement = self.stack[-1]["selectable"]

        For versions 1.4 and above, ensure only .current_executable
        is used; the format of "self.stack" may change.


        """
        try:
            # the most recently pushed stack entry carries the statement
            # currently being rendered
            return self.stack[-1]["selectable"]
        except IndexError as ie:
            # empty stack: no statement is currently being compiled
            raise IndexError("Compiler does not have a stack entry") from ie
1589
1590 @property
1591 def prefetch(self):
1592 return list(self.insert_prefetch) + list(self.update_prefetch)
1593
    @util.memoized_property
    def _global_attributes(self) -> Dict[Any, Any]:
        # lazily-created dictionary; memoized so that every access on this
        # compiler returns the same dict instance
        return {}
1597
1598 @util.memoized_instancemethod
1599 def _init_cte_state(self) -> MutableMapping[CTE, str]:
1600 """Initialize collections related to CTEs only if
1601 a CTE is located, to save on the overhead of
1602 these collections otherwise.
1603
1604 """
1605 # collect CTEs to tack on top of a SELECT
1606 # To store the query to print - Dict[cte, text_query]
1607 ctes: MutableMapping[CTE, str] = util.OrderedDict()
1608 self.ctes = ctes
1609
1610 # Detect same CTE references - Dict[(level, name), cte]
1611 # Level is required for supporting nesting
1612 self.ctes_by_level_name = {}
1613
1614 # To retrieve key/level in ctes_by_level_name -
1615 # Dict[cte_reference, (level, cte_name, cte_opts)]
1616 self.level_name_by_cte = {}
1617
1618 self.ctes_recursive = False
1619
1620 return ctes
1621
1622 @contextlib.contextmanager
1623 def _nested_result(self):
1624 """special API to support the use case of 'nested result sets'"""
1625 result_columns, ordered_columns = (
1626 self._result_columns,
1627 self._ordered_columns,
1628 )
1629 self._result_columns, self._ordered_columns = [], False
1630
1631 try:
1632 if self.stack:
1633 entry = self.stack[-1]
1634 entry["need_result_map_for_nested"] = True
1635 else:
1636 entry = None
1637 yield self._result_columns, self._ordered_columns
1638 finally:
1639 if entry:
1640 entry.pop("need_result_map_for_nested")
1641 self._result_columns, self._ordered_columns = (
1642 result_columns,
1643 ordered_columns,
1644 )
1645
    def _process_positional(self):
        """Rewrite the pyformat-style intermediate string into the
        dialect's positional paramstyle ("format" or "qmark"), recording
        parameter order in ``self.positiontup``.

        """
        assert not self.positiontup
        assert self.state is CompilerState.STRING_APPLIED
        assert not self._numeric_binds

        if self.dialect.paramstyle == "format":
            placeholder = "%s"
        else:
            assert self.dialect.paramstyle == "qmark"
            placeholder = "?"

        positions = []

        def find_position(m: re.Match[str]) -> str:
            normal_bind = m.group(1)
            if normal_bind:
                positions.append(normal_bind)
                return placeholder
            else:
                # this is a post-compile bind; leave its token in the
                # string for later expansion, but still record its position
                positions.append(m.group(2))
                return m.group(0)

        self.string = re.sub(
            self._positional_pattern, find_position, self.string
        )

        if self.escaped_bind_names:
            # positiontup always carries the *unescaped* parameter names
            reverse_escape = {v: k for k, v in self.escaped_bind_names.items()}
            assert len(self.escaped_bind_names) == len(reverse_escape)
            self.positiontup = [
                reverse_escape.get(name, name) for name in positions
            ]
        else:
            self.positiontup = positions

        if self._insertmanyvalues:
            # rebind ``positions`` to a fresh list; the ``find_position``
            # closure refers to this enclosing name, so the substitutions
            # below append to the new list rather than the one recorded
            # into positiontup above
            positions = []

            single_values_expr = re.sub(
                self._positional_pattern,
                find_position,
                self._insertmanyvalues.single_values_expr,
            )
            insert_crud_params = [
                (
                    v[0],
                    v[1],
                    re.sub(self._positional_pattern, find_position, v[2]),
                    v[3],
                )
                for v in self._insertmanyvalues.insert_crud_params
            ]

            self._insertmanyvalues = self._insertmanyvalues._replace(
                single_values_expr=single_values_expr,
                insert_crud_params=insert_crud_params,
            )
1704
    def _process_numeric(self):
        """Rewrite the pyformat-style intermediate string into the
        "numeric" paramstyle, e.g. ``:1, :2`` or ``$1, $2``.

        """
        assert self._numeric_binds
        assert self.state is CompilerState.STRING_APPLIED

        num = 1
        param_pos: Dict[str, str] = {}
        order: Iterable[str]
        if self._insertmanyvalues and self._values_bindparam is not None:
            # bindparams that are not in values are always placed first.
            # this avoids the need of changing them when using executemany
            # values () ()
            order = itertools.chain(
                (
                    name
                    for name in self.bind_names.values()
                    if name not in self._values_bindparam
                ),
                self.bind_names.values(),
            )
        else:
            order = self.bind_names.values()

        for bind_name in order:
            if bind_name in param_pos:
                # names may appear twice in ``order`` above; number each
                # name only once
                continue
            bind = self.binds[bind_name]
            if (
                bind in self.post_compile_params
                or bind in self.literal_execute_params
            ):
                # set to None to just mark it in positiontup; it will not
                # be replaced below.
                param_pos[bind_name] = None  # type: ignore
            else:
                ph = f"{self._numeric_binds_identifier_char}{num}"
                num += 1
                param_pos[bind_name] = ph

        # one past the last placeholder number assigned above
        self.next_numeric_pos = num

        self.positiontup = list(param_pos)
        if self.escaped_bind_names:
            # rekey to the escaped names that actually appear in the
            # rendered string
            len_before = len(param_pos)
            param_pos = {
                self.escaped_bind_names.get(name, name): pos
                for name, pos in param_pos.items()
            }
            assert len(param_pos) == len_before

        # Can't use format here since % chars are not escaped.
        self.string = self._pyformat_pattern.sub(
            lambda m: param_pos[m.group(1)], self.string
        )

        if self._insertmanyvalues:
            single_values_expr = (
                # format is ok here since single_values_expr includes only
                # place-holders
                self._insertmanyvalues.single_values_expr
                % param_pos
            )
            insert_crud_params = [
                (v[0], v[1], "%s", v[3])
                for v in self._insertmanyvalues.insert_crud_params
            ]

            self._insertmanyvalues = self._insertmanyvalues._replace(
                # This has the numbers (:1, :2)
                single_values_expr=single_values_expr,
                # The single binds are instead %s so they can be formatted
                insert_crud_params=insert_crud_params,
            )
1777
1778 @util.memoized_property
1779 def _bind_processors(
1780 self,
1781 ) -> MutableMapping[
1782 str, Union[_BindProcessorType[Any], Sequence[_BindProcessorType[Any]]]
1783 ]:
1784 # mypy is not able to see the two value types as the above Union,
1785 # it just sees "object". don't know how to resolve
1786 return {
1787 key: value # type: ignore
1788 for key, value in (
1789 (
1790 self.bind_names[bindparam],
1791 (
1792 bindparam.type._cached_bind_processor(self.dialect)
1793 if not bindparam.type._is_tuple_type
1794 else tuple(
1795 elem_type._cached_bind_processor(self.dialect)
1796 for elem_type in cast(
1797 TupleType, bindparam.type
1798 ).types
1799 )
1800 ),
1801 )
1802 for bindparam in self.bind_names
1803 )
1804 if value is not None
1805 }
1806
    def is_subquery(self):
        # more than one entry on the statement stack means we are rendering
        # inside an enclosing statement
        return len(self.stack) > 1
1809
    @property
    def sql_compiler(self) -> Self:
        """Return self; a SQLCompiler is its own SQL string compiler."""
        return self
1813
1814 def construct_expanded_state(
1815 self,
1816 params: Optional[_CoreSingleExecuteParams] = None,
1817 escape_names: bool = True,
1818 ) -> ExpandedState:
1819 """Return a new :class:`.ExpandedState` for a given parameter set.
1820
1821 For queries that use "expanding" or other late-rendered parameters,
1822 this method will provide for both the finalized SQL string as well
1823 as the parameters that would be used for a particular parameter set.
1824
1825 .. versionadded:: 2.0.0rc1
1826
1827 """
1828 parameters = self.construct_params(
1829 params,
1830 escape_names=escape_names,
1831 _no_postcompile=True,
1832 )
1833 return self._process_parameters_for_postcompile(
1834 parameters,
1835 )
1836
1837 def construct_params(
1838 self,
1839 params: Optional[_CoreSingleExecuteParams] = None,
1840 extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
1841 escape_names: bool = True,
1842 _group_number: Optional[int] = None,
1843 _check: bool = True,
1844 _no_postcompile: bool = False,
1845 ) -> _MutableCoreSingleExecuteParams:
1846 """return a dictionary of bind parameter keys and values"""
1847
1848 if self._render_postcompile and not _no_postcompile:
1849 assert self._post_compile_expanded_state is not None
1850 if not params:
1851 return dict(self._post_compile_expanded_state.parameters)
1852 else:
1853 raise exc.InvalidRequestError(
1854 "can't construct new parameters when render_postcompile "
1855 "is used; the statement is hard-linked to the original "
1856 "parameters. Use construct_expanded_state to generate a "
1857 "new statement and parameters."
1858 )
1859
1860 has_escaped_names = escape_names and bool(self.escaped_bind_names)
1861
1862 if extracted_parameters:
1863 # related the bound parameters collected in the original cache key
1864 # to those collected in the incoming cache key. They will not have
1865 # matching names but they will line up positionally in the same
1866 # way. The parameters present in self.bind_names may be clones of
1867 # these original cache key params in the case of DML but the .key
1868 # will be guaranteed to match.
1869 if self.cache_key is None:
1870 raise exc.CompileError(
1871 "This compiled object has no original cache key; "
1872 "can't pass extracted_parameters to construct_params"
1873 )
1874 else:
1875 orig_extracted = self.cache_key[1]
1876
1877 ckbm_tuple = self._cache_key_bind_match
1878 assert ckbm_tuple is not None
1879 ckbm, _ = ckbm_tuple
1880 resolved_extracted = {
1881 bind: extracted
1882 for b, extracted in zip(orig_extracted, extracted_parameters)
1883 for bind in ckbm[b]
1884 }
1885 else:
1886 resolved_extracted = None
1887
1888 if params:
1889 pd = {}
1890 for bindparam, name in self.bind_names.items():
1891 escaped_name = (
1892 self.escaped_bind_names.get(name, name)
1893 if has_escaped_names
1894 else name
1895 )
1896
1897 if bindparam.key in params:
1898 pd[escaped_name] = params[bindparam.key]
1899 elif name in params:
1900 pd[escaped_name] = params[name]
1901
1902 elif _check and bindparam.required:
1903 if _group_number:
1904 raise exc.InvalidRequestError(
1905 "A value is required for bind parameter %r, "
1906 "in parameter group %d"
1907 % (bindparam.key, _group_number),
1908 code="cd3x",
1909 )
1910 else:
1911 raise exc.InvalidRequestError(
1912 "A value is required for bind parameter %r"
1913 % bindparam.key,
1914 code="cd3x",
1915 )
1916 else:
1917 if resolved_extracted:
1918 value_param = resolved_extracted.get(
1919 bindparam, bindparam
1920 )
1921 else:
1922 value_param = bindparam
1923
1924 if bindparam.callable:
1925 pd[escaped_name] = value_param.effective_value
1926 else:
1927 pd[escaped_name] = value_param.value
1928 return pd
1929 else:
1930 pd = {}
1931 for bindparam, name in self.bind_names.items():
1932 escaped_name = (
1933 self.escaped_bind_names.get(name, name)
1934 if has_escaped_names
1935 else name
1936 )
1937
1938 if _check and bindparam.required:
1939 if _group_number:
1940 raise exc.InvalidRequestError(
1941 "A value is required for bind parameter %r, "
1942 "in parameter group %d"
1943 % (bindparam.key, _group_number),
1944 code="cd3x",
1945 )
1946 else:
1947 raise exc.InvalidRequestError(
1948 "A value is required for bind parameter %r"
1949 % bindparam.key,
1950 code="cd3x",
1951 )
1952
1953 if resolved_extracted:
1954 value_param = resolved_extracted.get(bindparam, bindparam)
1955 else:
1956 value_param = bindparam
1957
1958 if bindparam.callable:
1959 pd[escaped_name] = value_param.effective_value
1960 else:
1961 pd[escaped_name] = value_param.value
1962
1963 return pd
1964
1965 @util.memoized_instancemethod
1966 def _get_set_input_sizes_lookup(self):
1967 dialect = self.dialect
1968
1969 include_types = dialect.include_set_input_sizes
1970 exclude_types = dialect.exclude_set_input_sizes
1971
1972 dbapi = dialect.dbapi
1973
1974 def lookup_type(typ):
1975 dbtype = typ._unwrapped_dialect_impl(dialect).get_dbapi_type(dbapi)
1976
1977 if (
1978 dbtype is not None
1979 and (exclude_types is None or dbtype not in exclude_types)
1980 and (include_types is None or dbtype in include_types)
1981 ):
1982 return dbtype
1983 else:
1984 return None
1985
1986 inputsizes = {}
1987
1988 literal_execute_params = self.literal_execute_params
1989
1990 for bindparam in self.bind_names:
1991 if bindparam in literal_execute_params:
1992 continue
1993
1994 if bindparam.type._is_tuple_type:
1995 inputsizes[bindparam] = [
1996 lookup_type(typ)
1997 for typ in cast(TupleType, bindparam.type).types
1998 ]
1999 else:
2000 inputsizes[bindparam] = lookup_type(bindparam.type)
2001
2002 return inputsizes
2003
    @property
    def params(self):
        """Return the bind param dictionary embedded into this
        compiled object, for those values that are present.

        .. seealso::

            :ref:`faq_sql_expression_string` - includes a usage example for
            debugging use cases.

        """
        # _check=False: "required" parameters with no value do not raise;
        # the dictionary simply reflects what is available
        return self.construct_params(_check=False)
2016
2017 def _process_parameters_for_postcompile(
2018 self,
2019 parameters: _MutableCoreSingleExecuteParams,
2020 _populate_self: bool = False,
2021 ) -> ExpandedState:
2022 """handle special post compile parameters.
2023
2024 These include:
2025
2026 * "expanding" parameters -typically IN tuples that are rendered
2027 on a per-parameter basis for an otherwise fixed SQL statement string.
2028
2029 * literal_binds compiled with the literal_execute flag. Used for
2030 things like SQL Server "TOP N" where the driver does not accommodate
2031 N as a bound parameter.
2032
2033 """
2034
2035 expanded_parameters = {}
2036 new_positiontup: Optional[List[str]]
2037
2038 pre_expanded_string = self._pre_expanded_string
2039 if pre_expanded_string is None:
2040 pre_expanded_string = self.string
2041
2042 if self.positional:
2043 new_positiontup = []
2044
2045 pre_expanded_positiontup = self._pre_expanded_positiontup
2046 if pre_expanded_positiontup is None:
2047 pre_expanded_positiontup = self.positiontup
2048
2049 else:
2050 new_positiontup = pre_expanded_positiontup = None
2051
2052 processors = self._bind_processors
2053 single_processors = cast(
2054 "Mapping[str, _BindProcessorType[Any]]", processors
2055 )
2056 tuple_processors = cast(
2057 "Mapping[str, Sequence[_BindProcessorType[Any]]]", processors
2058 )
2059
2060 new_processors: Dict[str, _BindProcessorType[Any]] = {}
2061
2062 replacement_expressions: Dict[str, Any] = {}
2063 to_update_sets: Dict[str, Any] = {}
2064
2065 # notes:
2066 # *unescaped* parameter names in:
2067 # self.bind_names, self.binds, self._bind_processors, self.positiontup
2068 #
2069 # *escaped* parameter names in:
2070 # construct_params(), replacement_expressions
2071
2072 numeric_positiontup: Optional[List[str]] = None
2073
2074 if self.positional and pre_expanded_positiontup is not None:
2075 names: Iterable[str] = pre_expanded_positiontup
2076 if self._numeric_binds:
2077 numeric_positiontup = []
2078 else:
2079 names = self.bind_names.values()
2080
2081 ebn = self.escaped_bind_names
2082 for name in names:
2083 escaped_name = ebn.get(name, name) if ebn else name
2084 parameter = self.binds[name]
2085
2086 if parameter in self.literal_execute_params:
2087 if escaped_name not in replacement_expressions:
2088 replacement_expressions[escaped_name] = (
2089 self.render_literal_bindparam(
2090 parameter,
2091 render_literal_value=parameters.pop(escaped_name),
2092 )
2093 )
2094 continue
2095
2096 if parameter in self.post_compile_params:
2097 if escaped_name in replacement_expressions:
2098 to_update = to_update_sets[escaped_name]
2099 values = None
2100 else:
2101 # we are removing the parameter from parameters
2102 # because it is a list value, which is not expected by
2103 # TypeEngine objects that would otherwise be asked to
2104 # process it. the single name is being replaced with
2105 # individual numbered parameters for each value in the
2106 # param.
2107 #
2108 # note we are also inserting *escaped* parameter names
2109 # into the given dictionary. default dialect will
2110 # use these param names directly as they will not be
2111 # in the escaped_bind_names dictionary.
2112 values = parameters.pop(name)
2113
2114 leep_res = self._literal_execute_expanding_parameter(
2115 escaped_name, parameter, values
2116 )
2117 (to_update, replacement_expr) = leep_res
2118
2119 to_update_sets[escaped_name] = to_update
2120 replacement_expressions[escaped_name] = replacement_expr
2121
2122 if not parameter.literal_execute:
2123 parameters.update(to_update)
2124 if parameter.type._is_tuple_type:
2125 assert values is not None
2126 new_processors.update(
2127 (
2128 "%s_%s_%s" % (name, i, j),
2129 tuple_processors[name][j - 1],
2130 )
2131 for i, tuple_element in enumerate(values, 1)
2132 for j, _ in enumerate(tuple_element, 1)
2133 if name in tuple_processors
2134 and tuple_processors[name][j - 1] is not None
2135 )
2136 else:
2137 new_processors.update(
2138 (key, single_processors[name])
2139 for key, _ in to_update
2140 if name in single_processors
2141 )
2142 if numeric_positiontup is not None:
2143 numeric_positiontup.extend(
2144 name for name, _ in to_update
2145 )
2146 elif new_positiontup is not None:
2147 # to_update has escaped names, but that's ok since
2148 # these are new names, that aren't in the
2149 # escaped_bind_names dict.
2150 new_positiontup.extend(name for name, _ in to_update)
2151 expanded_parameters[name] = [
2152 expand_key for expand_key, _ in to_update
2153 ]
2154 elif new_positiontup is not None:
2155 new_positiontup.append(name)
2156
2157 def process_expanding(m):
2158 key = m.group(1)
2159 expr = replacement_expressions[key]
2160
2161 # if POSTCOMPILE included a bind_expression, render that
2162 # around each element
2163 if m.group(2):
2164 tok = m.group(2).split("~~")
2165 be_left, be_right = tok[1], tok[3]
2166 expr = ", ".join(
2167 "%s%s%s" % (be_left, exp, be_right)
2168 for exp in expr.split(", ")
2169 )
2170 return expr
2171
2172 statement = re.sub(
2173 self._post_compile_pattern, process_expanding, pre_expanded_string
2174 )
2175
2176 if numeric_positiontup is not None:
2177 assert new_positiontup is not None
2178 param_pos = {
2179 key: f"{self._numeric_binds_identifier_char}{num}"
2180 for num, key in enumerate(
2181 numeric_positiontup, self.next_numeric_pos
2182 )
2183 }
2184 # Can't use format here since % chars are not escaped.
2185 statement = self._pyformat_pattern.sub(
2186 lambda m: param_pos[m.group(1)], statement
2187 )
2188 new_positiontup.extend(numeric_positiontup)
2189
2190 expanded_state = ExpandedState(
2191 statement,
2192 parameters,
2193 new_processors,
2194 new_positiontup,
2195 expanded_parameters,
2196 )
2197
2198 if _populate_self:
2199 # this is for the "render_postcompile" flag, which is not
2200 # otherwise used internally and is for end-user debugging and
2201 # special use cases.
2202 self._pre_expanded_string = pre_expanded_string
2203 self._pre_expanded_positiontup = pre_expanded_positiontup
2204 self.string = expanded_state.statement
2205 self.positiontup = (
2206 list(expanded_state.positiontup or ())
2207 if self.positional
2208 else None
2209 )
2210 self._post_compile_expanded_state = expanded_state
2211
2212 return expanded_state
2213
2214 @util.preload_module("sqlalchemy.engine.cursor")
2215 def _create_result_map(self):
2216 """utility method used for unit tests only."""
2217 cursor = util.preloaded.engine_cursor
2218 return cursor.CursorResultMetaData._create_description_match_map(
2219 self._result_columns
2220 )
2221
    # assigned by crud.py for insert/update statements; callable returning
    # the bound-parameter name used within the statement for a given column
    _get_bind_name_for_col: _BindNameForColProtocol
2224
2225 @util.memoized_property
2226 def _within_exec_param_key_getter(self) -> Callable[[Any], str]:
2227 getter = self._get_bind_name_for_col
2228 return getter
2229
    @util.memoized_property
    @util.preload_module("sqlalchemy.engine.result")
    def _inserted_primary_key_from_lastrowid_getter(self):
        """Return a callable ``get(lastrowid, parameters)`` that produces the
        inserted primary key "row" for an INSERT, combining
        ``cursor.lastrowid`` with values present in the statement's
        parameters.

        """
        result = util.preloaded.engine_result

        param_key_getter = self._within_exec_param_key_getter

        assert self.compile_state is not None
        statement = self.compile_state.statement

        if TYPE_CHECKING:
            assert isinstance(statement, Insert)

        table = statement.table

        # one (getter, column) pair per primary key column; each getter
        # pulls that column's value out of the parameter dictionary
        getters = [
            (operator.methodcaller("get", param_key_getter(col), None), col)
            for col in table.primary_key
        ]

        autoinc_getter = None
        autoinc_col = table._autoincrement_column
        if autoinc_col is not None:
            # apply type post processors to the lastrowid
            lastrowid_processor = autoinc_col.type._cached_result_processor(
                self.dialect, None
            )
            autoinc_key = param_key_getter(autoinc_col)

            # if a bind value is present for the autoincrement column
            # in the parameters, we need to do the logic dictated by
            # #7998; honor a non-None user-passed parameter over lastrowid.
            # previously in the 1.4 series we weren't fetching lastrowid
            # at all if the key were present in the parameters
            if autoinc_key in self.binds:

                def _autoinc_getter(lastrowid, parameters):
                    param_value = parameters.get(autoinc_key, lastrowid)
                    if param_value is not None:
                        # they supplied non-None parameter, use that.
                        # SQLite at least is observed to return the wrong
                        # cursor.lastrowid for INSERT..ON CONFLICT so it
                        # can't be used in all cases
                        return param_value
                    else:
                        # use lastrowid
                        return lastrowid

                # work around mypy https://github.com/python/mypy/issues/14027
                autoinc_getter = _autoinc_getter

        else:
            lastrowid_processor = None

        row_fn = result.result_tuple([col.key for col in table.primary_key])

        def get(lastrowid, parameters):
            """given cursor.lastrowid value and the parameters used for INSERT,
            return a "row" that represents the primary key, either by
            using the "lastrowid" or by extracting values from the parameters
            that were sent along with the INSERT.

            """
            if lastrowid_processor is not None:
                lastrowid = lastrowid_processor(lastrowid)

            if lastrowid is None:
                return row_fn(getter(parameters) for getter, col in getters)
            else:
                return row_fn(
                    (
                        (
                            autoinc_getter(lastrowid, parameters)
                            if autoinc_getter is not None
                            else lastrowid
                        )
                        if col is autoinc_col
                        else getter(parameters)
                    )
                    for getter, col in getters
                )

        return get
2313
    @util.memoized_property
    @util.preload_module("sqlalchemy.engine.result")
    def _inserted_primary_key_from_returning_getter(self):
        """Return a callable ``get(row, parameters)`` producing the inserted
        primary key row, preferring values from the RETURNING row and
        falling back to the INSERT parameters for columns not present in
        RETURNING.

        """
        if typing.TYPE_CHECKING:
            from ..engine import result
        else:
            result = util.preloaded.engine_result

        assert self.compile_state is not None
        statement = self.compile_state.statement

        if TYPE_CHECKING:
            assert isinstance(statement, Insert)

        param_key_getter = self._within_exec_param_key_getter
        table = statement.table

        returning = self.implicit_returning
        assert returning is not None
        # index of each RETURNING column within the returned row
        ret = {col: idx for idx, col in enumerate(returning)}

        # one (getter, use_row) pair per primary key column; use_row=True
        # reads from the RETURNING row, False reads from the parameters
        getters = cast(
            "List[Tuple[Callable[[Any], Any], bool]]",
            [
                (
                    (operator.itemgetter(ret[col]), True)
                    if col in ret
                    else (
                        operator.methodcaller(
                            "get", param_key_getter(col), None
                        ),
                        False,
                    )
                )
                for col in table.primary_key
            ],
        )

        row_fn = result.result_tuple([col.key for col in table.primary_key])

        def get(row, parameters):
            return row_fn(
                getter(row) if use_row else getter(parameters)
                for getter, use_row in getters
            )

        return get
2361
2362 def default_from(self) -> str:
2363 """Called when a SELECT statement has no froms, and no FROM clause is
2364 to be appended.
2365
2366 Gives Oracle Database a chance to tack on a ``FROM DUAL`` to the string
2367 output.
2368
2369 """
2370 return ""
2371
    def visit_override_binds(self, override_binds, **kw):
        """SQL compile the nested element of an _OverrideBinds with
        bindparams swapped out.

        The _OverrideBinds is not normally expected to be compiled; it
        is meant to be used when an already cached statement is to be used,
        the compilation was already performed, and only the bound params should
        be swapped in at execution time.

        However, there are test cases that exercise this object, and
        additionally the ORM subquery loader is known to feed in expressions
        which include this construct into new queries (discovered in #11173),
        so it has to do the right thing at compile time as well.

        """

        # get SQL text first
        sqltext = override_binds.element._compiler_dispatch(self, **kw)

        # for a test compile that is not for caching, change binds after the
        # fact. note that we don't try to
        # swap the bindparam as we compile, because our element may be
        # elsewhere in the statement already (e.g. a subquery or perhaps a
        # CTE) and was already visited / compiled. See
        # test_relationship_criteria.py ->
        # test_selectinload_local_criteria_subquery
        for k in override_binds.translate:
            if k not in self.binds:
                continue
            bp = self.binds[k]

            # so this would work, just change the value of bp in place.
            # but we dont want to mutate things outside.
            # bp.value = override_binds.translate[bp.key]
            # continue

            # instead, need to replace bp with new_bp or otherwise accommodate
            # in all internal collections
            new_bp = bp._with_value(
                override_binds.translate[bp.key],
                maintain_key=True,
                required=False,
            )

            # re-point the internal collections at the replacement param
            name = self.bind_names[bp]
            self.binds[k] = self.binds[name] = new_bp
            self.bind_names[new_bp] = name
            self.bind_names.pop(bp, None)

            if bp in self.post_compile_params:
                self.post_compile_params |= {new_bp}
            if bp in self.literal_execute_params:
                self.literal_execute_params |= {new_bp}

            # keep the cache-key-to-bind matching structure in sync so the
            # replacement param can still be located by cache key
            ckbm_tuple = self._cache_key_bind_match
            if ckbm_tuple:
                ckbm, cksm = ckbm_tuple
                for bp in bp._cloned_set:
                    if bp.key in cksm:
                        cb = cksm[bp.key]
                        ckbm[cb].append(new_bp)

        return sqltext
2435
2436 def visit_grouping(self, grouping, asfrom=False, **kwargs):
2437 return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
2438
2439 def visit_select_statement_grouping(self, grouping, **kwargs):
2440 return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
2441
    def visit_label_reference(
        self, element, within_columns_clause=False, **kwargs
    ):
        """Render a reference to a labeled expression in ORDER BY /
        GROUP BY, emitting just the label name when the dialect supports
        simple ORDER BY label and the label resolves in the enclosing
        SELECT.

        """
        if self.stack and self.dialect.supports_simple_order_by_label:
            try:
                compile_state = cast(
                    "Union[SelectState, CompoundSelectState]",
                    self.stack[-1]["compile_state"],
                )
            except KeyError as ke:
                raise exc.CompileError(
                    "Can't resolve label reference for ORDER BY / "
                    "GROUP BY / DISTINCT etc."
                ) from ke

            (
                with_cols,
                only_froms,
                only_cols,
            ) = compile_state._label_resolve_dict
            if within_columns_clause:
                resolve_dict = only_froms
            else:
                resolve_dict = only_cols

            # this can be None in the case that a _label_reference()
            # were subject to a replacement operation, in which case
            # the replacement of the Label element may have changed
            # to something else like a ColumnClause expression.
            order_by_elem = element.element._order_by_label_element

            if (
                order_by_elem is not None
                and order_by_elem.name in resolve_dict
                and order_by_elem.shares_lineage(
                    resolve_dict[order_by_elem.name]
                )
            ):
                # label resolves in the enclosing SELECT; instruct
                # visit_label to render only the label name
                kwargs["render_label_as_label"] = (
                    element.element._order_by_label_element
                )
        return self.process(
            element.element,
            within_columns_clause=within_columns_clause,
            **kwargs,
        )
2488
    def visit_textual_label_reference(
        self, element, within_columns_clause=False, **kwargs
    ):
        """Render a string-named label reference in ORDER BY / GROUP BY
        by resolving the text against the enclosing SELECT's label
        resolution dictionaries.

        """
        if not self.stack:
            # compiling the element outside of the context of a SELECT
            return self.process(element._text_clause)

        try:
            compile_state = cast(
                "Union[SelectState, CompoundSelectState]",
                self.stack[-1]["compile_state"],
            )
        except KeyError as ke:
            # no compile_state on the enclosing stack entry; can't resolve
            coercions._no_text_coercion(
                element.element,
                extra=(
                    "Can't resolve label reference for ORDER BY / "
                    "GROUP BY / DISTINCT etc."
                ),
                exc_cls=exc.CompileError,
                err=ke,
            )

        with_cols, only_froms, only_cols = compile_state._label_resolve_dict
        try:
            if within_columns_clause:
                col = only_froms[element.element]
            else:
                col = with_cols[element.element]
        except KeyError as err:
            coercions._no_text_coercion(
                element.element,
                extra=(
                    "Can't resolve label reference for ORDER BY / "
                    "GROUP BY / DISTINCT etc."
                ),
                exc_cls=exc.CompileError,
                err=err,
            )
        else:
            kwargs["render_label_as_label"] = col
            return self.process(
                col, within_columns_clause=within_columns_clause, **kwargs
            )
2533
2534 def visit_label(
2535 self,
2536 label,
2537 add_to_result_map=None,
2538 within_label_clause=False,
2539 within_columns_clause=False,
2540 render_label_as_label=None,
2541 result_map_targets=(),
2542 **kw,
2543 ):
2544 # only render labels within the columns clause
2545 # or ORDER BY clause of a select. dialect-specific compilers
2546 # can modify this behavior.
2547 render_label_with_as = (
2548 within_columns_clause and not within_label_clause
2549 )
2550 render_label_only = render_label_as_label is label
2551
2552 if render_label_only or render_label_with_as:
2553 if isinstance(label.name, elements._truncated_label):
2554 labelname = self._truncated_identifier("colident", label.name)
2555 else:
2556 labelname = label.name
2557
2558 if render_label_with_as:
2559 if add_to_result_map is not None:
2560 add_to_result_map(
2561 labelname,
2562 label.name,
2563 (label, labelname) + label._alt_names + result_map_targets,
2564 label.type,
2565 )
2566 return (
2567 label.element._compiler_dispatch(
2568 self,
2569 within_columns_clause=True,
2570 within_label_clause=True,
2571 **kw,
2572 )
2573 + OPERATORS[operators.as_]
2574 + self.preparer.format_label(label, labelname)
2575 )
2576 elif render_label_only:
2577 return self.preparer.format_label(label, labelname)
2578 else:
2579 return label.element._compiler_dispatch(
2580 self, within_columns_clause=False, **kw
2581 )
2582
2583 def _fallback_column_name(self, column):
2584 raise exc.CompileError(
2585 "Cannot compile Column object until its 'name' is assigned."
2586 )
2587
2588 def visit_lambda_element(self, element, **kw):
2589 sql_element = element._resolved
2590 return self.process(sql_element, **kw)
2591
    def visit_column(
        self,
        column: ColumnClause[Any],
        add_to_result_map: Optional[_ResultMapAppender] = None,
        include_table: bool = True,
        result_map_targets: Tuple[Any, ...] = (),
        ambiguous_table_name_map: Optional[_AmbiguousTableNameMap] = None,
        **kwargs: Any,
    ) -> str:
        """Render a column expression, optionally table- and
        schema-qualified, registering it in the result map if requested.

        :param column: the :class:`.ColumnClause` being rendered
        :param add_to_result_map: callable used to record the column in
         the statement's result map
        :param include_table: when False, render the name without its
         table qualifier
        :param result_map_targets: extra match targets passed along to the
         result map
        :param ambiguous_table_name_map: alternate table names to use when
         a schema-less table name would otherwise be ambiguous

        """
        name = orig_name = column.name
        if name is None:
            # base implementation raises CompileError here
            name = self._fallback_column_name(column)

        is_literal = column.is_literal
        if not is_literal and isinstance(name, elements._truncated_label):
            name = self._truncated_identifier("colident", name)

        if add_to_result_map is not None:
            targets = (column, name, column.key) + result_map_targets
            if column._tq_label:
                targets += (column._tq_label,)

            add_to_result_map(name, orig_name, targets, column.type)

        if is_literal:
            # note we are not currently accommodating for
            # literal_column(quoted_name('ident', True)) here
            name = self.escape_literal_column(name)
        else:
            name = self.preparer.quote(name)
        table = column.table
        if table is None or not include_table or not table.named_with_column:
            return name
        else:
            effective_schema = self.preparer.schema_for_object(table)

            if effective_schema:
                schema_prefix = (
                    self.preparer.quote_schema(effective_schema) + "."
                )
            else:
                schema_prefix = ""

            if TYPE_CHECKING:
                assert isinstance(table, NamedFromClause)
            tablename = table.name

            if (
                not effective_schema
                and ambiguous_table_name_map
                and tablename in ambiguous_table_name_map
            ):
                tablename = ambiguous_table_name_map[tablename]

            if isinstance(tablename, elements._truncated_label):
                tablename = self._truncated_identifier("alias", tablename)

            return schema_prefix + self.preparer.quote(tablename) + "." + name
2650
2651 def visit_collation(self, element, **kw):
2652 return self.preparer.format_collation(element.collation)
2653
2654 def visit_fromclause(self, fromclause, **kwargs):
2655 return fromclause.name
2656
2657 def visit_index(self, index, **kwargs):
2658 return index.name
2659
2660 def visit_typeclause(self, typeclause, **kw):
2661 kw["type_expression"] = typeclause
2662 kw["identifier_preparer"] = self.preparer
2663 return self.dialect.type_compiler_instance.process(
2664 typeclause.type, **kw
2665 )
2666
2667 def post_process_text(self, text):
2668 if self.preparer._double_percents:
2669 text = text.replace("%", "%%")
2670 return text
2671
2672 def escape_literal_column(self, text):
2673 if self.preparer._double_percents:
2674 text = text.replace("%", "%%")
2675 return text
2676
    def visit_textclause(self, textclause, add_to_result_map=None, **kw):
        """Render a text() construct, substituting bound parameters for
        ``:name`` tokens and un-escaping backslash-escaped colons.

        """

        def do_bindparam(m):
            # replace each :name token with either its declared bindparam
            # or a newly generated parameter string
            name = m.group(1)
            if name in textclause._bindparams:
                return self.process(textclause._bindparams[name], **kw)
            else:
                return self.bindparam_string(name, **kw)

        if not self.stack:
            self.isplaintext = True

        if add_to_result_map:
            # text() object is present in the columns clause of a
            # select(). Add a no-name entry to the result map so that
            # row[text()] produces a result
            add_to_result_map(None, None, (textclause,), sqltypes.NULLTYPE)

        # un-escape any \:params
        return BIND_PARAMS_ESC.sub(
            lambda m: m.group(1),
            BIND_PARAMS.sub(
                do_bindparam, self.post_process_text(textclause.text)
            ),
        )
2701
    def visit_textual_select(
        self, taf, compound_index=None, asfrom=False, **kw
    ):
        """Render a TextualSelect (text() with column_args), establishing
        result-map entries for its columns when this is the outermost
        statement or a compound / nested member that needs them.

        """
        toplevel = not self.stack
        entry = self._default_stack_entry if toplevel else self.stack[-1]

        new_entry: _CompilerStackEntry = {
            "correlate_froms": set(),
            "asfrom_froms": set(),
            "selectable": taf,
        }
        self.stack.append(new_entry)

        if taf._independent_ctes:
            self._dispatch_independent_ctes(taf, kw)

        populate_result_map = (
            toplevel
            or (
                compound_index == 0
                and entry.get("need_result_map_for_compound", False)
            )
            or entry.get("need_result_map_for_nested", False)
        )

        if populate_result_map:
            self._ordered_columns = self._textual_ordered_columns = (
                taf.positional
            )

            # enable looser result column matching when the SQL text links to
            # Column objects by name only
            self._loose_column_name_matching = not taf.positional and bool(
                taf.column_args
            )

            for c in taf.column_args:
                self.process(
                    c,
                    within_columns_clause=True,
                    add_to_result_map=self._add_to_result_map,
                )

        text = self.process(taf.element, **kw)
        if self.ctes:
            # nesting_level of None indicates toplevel CTE rendering
            nesting_level = len(self.stack) if not toplevel else None
            text = self._render_cte_clause(nesting_level=nesting_level) + text

        self.stack.pop(-1)

        return text
2753
2754 def visit_null(self, expr: Null, **kw: Any) -> str:
2755 return "NULL"
2756
2757 def visit_true(self, expr: True_, **kw: Any) -> str:
2758 if self.dialect.supports_native_boolean:
2759 return "true"
2760 else:
2761 return "1"
2762
2763 def visit_false(self, expr: False_, **kw: Any) -> str:
2764 if self.dialect.supports_native_boolean:
2765 return "false"
2766 else:
2767 return "0"
2768
2769 def _generate_delimited_list(self, elements, separator, **kw):
2770 return separator.join(
2771 s
2772 for s in (c._compiler_dispatch(self, **kw) for c in elements)
2773 if s
2774 )
2775
2776 def _generate_delimited_and_list(self, clauses, **kw):
2777 lcc, clauses = elements.BooleanClauseList._process_clauses_for_boolean(
2778 operators.and_,
2779 elements.True_._singleton,
2780 elements.False_._singleton,
2781 clauses,
2782 )
2783 if lcc == 1:
2784 return clauses[0]._compiler_dispatch(self, **kw)
2785 else:
2786 separator = OPERATORS[operators.and_]
2787 return separator.join(
2788 s
2789 for s in (c._compiler_dispatch(self, **kw) for c in clauses)
2790 if s
2791 )
2792
2793 def visit_tuple(self, clauselist, **kw):
2794 return "(%s)" % self.visit_clauselist(clauselist, **kw)
2795
2796 def visit_clauselist(self, clauselist, **kw):
2797 sep = clauselist.operator
2798 if sep is None:
2799 sep = " "
2800 else:
2801 sep = OPERATORS[clauselist.operator]
2802
2803 return self._generate_delimited_list(clauselist.clauses, sep, **kw)
2804
2805 def visit_expression_clauselist(self, clauselist, **kw):
2806 operator_ = clauselist.operator
2807
2808 disp = self._get_operator_dispatch(
2809 operator_, "expression_clauselist", None
2810 )
2811 if disp:
2812 return disp(clauselist, operator_, **kw)
2813
2814 try:
2815 opstring = OPERATORS[operator_]
2816 except KeyError as err:
2817 raise exc.UnsupportedCompilationError(self, operator_) from err
2818 else:
2819 kw["_in_operator_expression"] = True
2820 return self._generate_delimited_list(
2821 clauselist.clauses, opstring, **kw
2822 )
2823
2824 def visit_case(self, clause, **kwargs):
2825 x = "CASE "
2826 if clause.value is not None:
2827 x += clause.value._compiler_dispatch(self, **kwargs) + " "
2828 for cond, result in clause.whens:
2829 x += (
2830 "WHEN "
2831 + cond._compiler_dispatch(self, **kwargs)
2832 + " THEN "
2833 + result._compiler_dispatch(self, **kwargs)
2834 + " "
2835 )
2836 if clause.else_ is not None:
2837 x += (
2838 "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " "
2839 )
2840 x += "END"
2841 return x
2842
2843 def visit_type_coerce(self, type_coerce, **kw):
2844 return type_coerce.typed_expression._compiler_dispatch(self, **kw)
2845
2846 def visit_cast(self, cast, **kwargs):
2847 type_clause = cast.typeclause._compiler_dispatch(self, **kwargs)
2848 match = re.match("(.*)( COLLATE .*)", type_clause)
2849 return "CAST(%s AS %s)%s" % (
2850 cast.clause._compiler_dispatch(self, **kwargs),
2851 match.group(1) if match else type_clause,
2852 match.group(2) if match else "",
2853 )
2854
2855 def _format_frame_clause(self, range_, **kw):
2856 return "%s AND %s" % (
2857 (
2858 "UNBOUNDED PRECEDING"
2859 if range_[0] is elements.RANGE_UNBOUNDED
2860 else (
2861 "CURRENT ROW"
2862 if range_[0] is elements.RANGE_CURRENT
2863 else (
2864 "%s PRECEDING"
2865 % (
2866 self.process(
2867 elements.literal(abs(range_[0])), **kw
2868 ),
2869 )
2870 if range_[0] < 0
2871 else "%s FOLLOWING"
2872 % (self.process(elements.literal(range_[0]), **kw),)
2873 )
2874 )
2875 ),
2876 (
2877 "UNBOUNDED FOLLOWING"
2878 if range_[1] is elements.RANGE_UNBOUNDED
2879 else (
2880 "CURRENT ROW"
2881 if range_[1] is elements.RANGE_CURRENT
2882 else (
2883 "%s PRECEDING"
2884 % (
2885 self.process(
2886 elements.literal(abs(range_[1])), **kw
2887 ),
2888 )
2889 if range_[1] < 0
2890 else "%s FOLLOWING"
2891 % (self.process(elements.literal(range_[1]), **kw),)
2892 )
2893 )
2894 ),
2895 )
2896
2897 def visit_over(self, over, **kwargs):
2898 text = over.element._compiler_dispatch(self, **kwargs)
2899 if over.range_ is not None:
2900 range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
2901 over.range_, **kwargs
2902 )
2903 elif over.rows is not None:
2904 range_ = "ROWS BETWEEN %s" % self._format_frame_clause(
2905 over.rows, **kwargs
2906 )
2907 elif over.groups is not None:
2908 range_ = "GROUPS BETWEEN %s" % self._format_frame_clause(
2909 over.groups, **kwargs
2910 )
2911 else:
2912 range_ = None
2913
2914 return "%s OVER (%s)" % (
2915 text,
2916 " ".join(
2917 [
2918 "%s BY %s"
2919 % (word, clause._compiler_dispatch(self, **kwargs))
2920 for word, clause in (
2921 ("PARTITION", over.partition_by),
2922 ("ORDER", over.order_by),
2923 )
2924 if clause is not None and len(clause)
2925 ]
2926 + ([range_] if range_ else [])
2927 ),
2928 )
2929
2930 def visit_withingroup(self, withingroup, **kwargs):
2931 return "%s WITHIN GROUP (ORDER BY %s)" % (
2932 withingroup.element._compiler_dispatch(self, **kwargs),
2933 withingroup.order_by._compiler_dispatch(self, **kwargs),
2934 )
2935
2936 def visit_funcfilter(self, funcfilter, **kwargs):
2937 return "%s FILTER (WHERE %s)" % (
2938 funcfilter.func._compiler_dispatch(self, **kwargs),
2939 funcfilter.criterion._compiler_dispatch(self, **kwargs),
2940 )
2941
2942 def visit_extract(self, extract, **kwargs):
2943 field = self.extract_map.get(extract.field, extract.field)
2944 return "EXTRACT(%s FROM %s)" % (
2945 field,
2946 extract.expr._compiler_dispatch(self, **kwargs),
2947 )
2948
2949 def visit_scalar_function_column(self, element, **kw):
2950 compiled_fn = self.visit_function(element.fn, **kw)
2951 compiled_col = self.visit_column(element, **kw)
2952 return "(%s).%s" % (compiled_fn, compiled_col)
2953
    def visit_function(
        self,
        func: Function[Any],
        add_to_result_map: Optional[_ResultMapAppender] = None,
        **kwargs: Any,
    ) -> str:
        """Render a SQL function call, preferring a ``visit_<name>_func``
        method on this compiler, then the known-function registry, then
        generic ``name(args)`` rendering.

        """
        if add_to_result_map is not None:
            add_to_result_map(func.name, func.name, (func.name,), func.type)

        disp = getattr(self, "visit_%s_func" % func.name.lower(), None)

        text: str

        if disp:
            # a compiler-specific renderer exists for this function name
            text = disp(func, **kwargs)
        else:
            name = FUNCTIONS.get(func._deannotate().__class__, None)
            if name:
                if func._has_args:
                    name += "%(expr)s"
            else:
                name = func.name
                name = (
                    self.preparer.quote(name)
                    if self.preparer._requires_quotes_illegal_chars(name)
                    or isinstance(name, elements.quoted_name)
                    else name
                )
                name = name + "%(expr)s"
            text = ".".join(
                [
                    (
                        self.preparer.quote(tok)
                        if self.preparer._requires_quotes_illegal_chars(tok)
                        # NOTE(review): this isinstance() checks ``name``
                        # rather than ``tok`` — confirm whether that is
                        # intentional before changing
                        or isinstance(name, elements.quoted_name)
                        else tok
                    )
                    for tok in func.packagenames
                ]
                + [name]
            ) % {"expr": self.function_argspec(func, **kwargs)}

        if func._with_ordinality:
            text += " WITH ORDINALITY"
        return text
2999
3000 def visit_next_value_func(self, next_value, **kw):
3001 return self.visit_sequence(next_value.sequence)
3002
3003 def visit_sequence(self, sequence, **kw):
3004 raise NotImplementedError(
3005 "Dialect '%s' does not support sequence increments."
3006 % self.dialect.name
3007 )
3008
3009 def function_argspec(self, func: Function[Any], **kwargs: Any) -> str:
3010 return func.clause_expr._compiler_dispatch(self, **kwargs)
3011
    def visit_compound_select(
        self, cs, asfrom=False, compound_index=None, **kwargs
    ):
        """Render a CompoundSelect (UNION / INTERSECT / EXCEPT etc.),
        joining the member selects with the compound keyword and applying
        GROUP BY / ORDER BY / row-limiting clauses and CTEs.

        """
        toplevel = not self.stack

        compile_state = cs._compile_state_factory(cs, self, **kwargs)

        if toplevel and not self.compile_state:
            self.compile_state = compile_state

        compound_stmt = compile_state.statement

        entry = self._default_stack_entry if toplevel else self.stack[-1]
        need_result_map = toplevel or (
            not compound_index
            and entry.get("need_result_map_for_compound", False)
        )

        # indicates there is already a CompoundSelect in play
        if compound_index == 0:
            entry["select_0"] = cs

        self.stack.append(
            {
                "correlate_froms": entry["correlate_froms"],
                "asfrom_froms": entry["asfrom_froms"],
                "selectable": cs,
                "compile_state": compile_state,
                "need_result_map_for_compound": need_result_map,
            }
        )

        if compound_stmt._independent_ctes:
            self._dispatch_independent_ctes(compound_stmt, kwargs)

        keyword = self.compound_keywords[cs.keyword]

        # render each member SELECT, passing its index so that only the
        # first member populates the result map
        text = (" " + keyword + " ").join(
            (
                c._compiler_dispatch(
                    self, asfrom=asfrom, compound_index=i, **kwargs
                )
                for i, c in enumerate(cs.selects)
            )
        )

        kwargs["include_table"] = False
        text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs))
        text += self.order_by_clause(cs, **kwargs)
        if cs._has_row_limiting_clause:
            text += self._row_limit_clause(cs, **kwargs)

        if self.ctes:
            # nesting_level of None indicates toplevel CTE rendering
            nesting_level = len(self.stack) if not toplevel else None
            text = (
                self._render_cte_clause(
                    nesting_level=nesting_level,
                    include_following_stack=True,
                )
                + text
            )

        self.stack.pop(-1)
        return text
3076
3077 def _row_limit_clause(self, cs, **kwargs):
3078 if cs._fetch_clause is not None:
3079 return self.fetch_clause(cs, **kwargs)
3080 else:
3081 return self.limit_clause(cs, **kwargs)
3082
3083 def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
3084 attrname = "visit_%s_%s%s" % (
3085 operator_.__name__,
3086 qualifier1,
3087 "_" + qualifier2 if qualifier2 else "",
3088 )
3089 return getattr(self, attrname, None)
3090
3091 def visit_unary(
3092 self, unary, add_to_result_map=None, result_map_targets=(), **kw
3093 ):
3094 if add_to_result_map is not None:
3095 result_map_targets += (unary,)
3096 kw["add_to_result_map"] = add_to_result_map
3097 kw["result_map_targets"] = result_map_targets
3098
3099 if unary.operator:
3100 if unary.modifier:
3101 raise exc.CompileError(
3102 "Unary expression does not support operator "
3103 "and modifier simultaneously"
3104 )
3105 disp = self._get_operator_dispatch(
3106 unary.operator, "unary", "operator"
3107 )
3108 if disp:
3109 return disp(unary, unary.operator, **kw)
3110 else:
3111 return self._generate_generic_unary_operator(
3112 unary, OPERATORS[unary.operator], **kw
3113 )
3114 elif unary.modifier:
3115 disp = self._get_operator_dispatch(
3116 unary.modifier, "unary", "modifier"
3117 )
3118 if disp:
3119 return disp(unary, unary.modifier, **kw)
3120 else:
3121 return self._generate_generic_unary_modifier(
3122 unary, OPERATORS[unary.modifier], **kw
3123 )
3124 else:
3125 raise exc.CompileError(
3126 "Unary expression has no operator or modifier"
3127 )
3128
3129 def visit_truediv_binary(self, binary, operator, **kw):
3130 if self.dialect.div_is_floordiv:
3131 return (
3132 self.process(binary.left, **kw)
3133 + " / "
3134 # TODO: would need a fast cast again here,
3135 # unless we want to use an implicit cast like "+ 0.0"
3136 + self.process(
3137 elements.Cast(
3138 binary.right,
3139 (
3140 binary.right.type
3141 if binary.right.type._type_affinity
3142 is sqltypes.Numeric
3143 else sqltypes.Numeric()
3144 ),
3145 ),
3146 **kw,
3147 )
3148 )
3149 else:
3150 return (
3151 self.process(binary.left, **kw)
3152 + " / "
3153 + self.process(binary.right, **kw)
3154 )
3155
3156 def visit_floordiv_binary(self, binary, operator, **kw):
3157 if (
3158 self.dialect.div_is_floordiv
3159 and binary.right.type._type_affinity is sqltypes.Integer
3160 ):
3161 return (
3162 self.process(binary.left, **kw)
3163 + " / "
3164 + self.process(binary.right, **kw)
3165 )
3166 else:
3167 return "FLOOR(%s)" % (
3168 self.process(binary.left, **kw)
3169 + " / "
3170 + self.process(binary.right, **kw)
3171 )
3172
3173 def visit_is_true_unary_operator(self, element, operator, **kw):
3174 if (
3175 element._is_implicitly_boolean
3176 or self.dialect.supports_native_boolean
3177 ):
3178 return self.process(element.element, **kw)
3179 else:
3180 return "%s = 1" % self.process(element.element, **kw)
3181
3182 def visit_is_false_unary_operator(self, element, operator, **kw):
3183 if (
3184 element._is_implicitly_boolean
3185 or self.dialect.supports_native_boolean
3186 ):
3187 return "NOT %s" % self.process(element.element, **kw)
3188 else:
3189 return "%s = 0" % self.process(element.element, **kw)
3190
3191 def visit_not_match_op_binary(self, binary, operator, **kw):
3192 return "NOT %s" % self.visit_binary(
3193 binary, override_operator=operators.match_op
3194 )
3195
3196 def visit_not_in_op_binary(self, binary, operator, **kw):
3197 # The brackets are required in the NOT IN operation because the empty
3198 # case is handled using the form "(col NOT IN (null) OR 1 = 1)".
3199 # The presence of the OR makes the brackets required.
3200 return "(%s)" % self._generate_generic_binary(
3201 binary, OPERATORS[operator], **kw
3202 )
3203
3204 def visit_empty_set_op_expr(self, type_, expand_op, **kw):
3205 if expand_op is operators.not_in_op:
3206 if len(type_) > 1:
3207 return "(%s)) OR (1 = 1" % (
3208 ", ".join("NULL" for element in type_)
3209 )
3210 else:
3211 return "NULL) OR (1 = 1"
3212 elif expand_op is operators.in_op:
3213 if len(type_) > 1:
3214 return "(%s)) AND (1 != 1" % (
3215 ", ".join("NULL" for element in type_)
3216 )
3217 else:
3218 return "NULL) AND (1 != 1"
3219 else:
3220 return self.visit_empty_set_expr(type_)
3221
    def visit_empty_set_expr(self, element_types, **kw):
        """Raise; dialects that can render an "empty set" expression
        override this (reached via :meth:`.visit_empty_set_op_expr` for
        expand operators other than IN / NOT IN).

        :param element_types: list of the element types involved.
        """
        raise NotImplementedError(
            "Dialect '%s' does not support empty set expression."
            % self.dialect.name
        )
3227
3228 def _literal_execute_expanding_parameter_literal_binds(
3229 self, parameter, values, bind_expression_template=None
3230 ):
3231 typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
3232
3233 if not values:
3234 # empty IN expression. note we don't need to use
3235 # bind_expression_template here because there are no
3236 # expressions to render.
3237
3238 if typ_dialect_impl._is_tuple_type:
3239 replacement_expression = (
3240 "VALUES " if self.dialect.tuple_in_values else ""
3241 ) + self.visit_empty_set_op_expr(
3242 parameter.type.types, parameter.expand_op
3243 )
3244
3245 else:
3246 replacement_expression = self.visit_empty_set_op_expr(
3247 [parameter.type], parameter.expand_op
3248 )
3249
3250 elif typ_dialect_impl._is_tuple_type or (
3251 typ_dialect_impl._isnull
3252 and isinstance(values[0], collections_abc.Sequence)
3253 and not isinstance(values[0], (str, bytes))
3254 ):
3255 if typ_dialect_impl._has_bind_expression:
3256 raise NotImplementedError(
3257 "bind_expression() on TupleType not supported with "
3258 "literal_binds"
3259 )
3260
3261 replacement_expression = (
3262 "VALUES " if self.dialect.tuple_in_values else ""
3263 ) + ", ".join(
3264 "(%s)"
3265 % (
3266 ", ".join(
3267 self.render_literal_value(value, param_type)
3268 for value, param_type in zip(
3269 tuple_element, parameter.type.types
3270 )
3271 )
3272 )
3273 for i, tuple_element in enumerate(values)
3274 )
3275 else:
3276 if bind_expression_template:
3277 post_compile_pattern = self._post_compile_pattern
3278 m = post_compile_pattern.search(bind_expression_template)
3279 assert m and m.group(
3280 2
3281 ), "unexpected format for expanding parameter"
3282
3283 tok = m.group(2).split("~~")
3284 be_left, be_right = tok[1], tok[3]
3285 replacement_expression = ", ".join(
3286 "%s%s%s"
3287 % (
3288 be_left,
3289 self.render_literal_value(value, parameter.type),
3290 be_right,
3291 )
3292 for value in values
3293 )
3294 else:
3295 replacement_expression = ", ".join(
3296 self.render_literal_value(value, parameter.type)
3297 for value in values
3298 )
3299
3300 return (), replacement_expression
3301
    def _literal_execute_expanding_parameter(self, name, parameter, values):
        """Expand an "expanding" (IN) parameter into individual binds.

        Returns ``(to_update, replacement_expression)`` where
        ``to_update`` is a list of ``(param_name, value)`` pairs to be
        added to the execution parameters and ``replacement_expression``
        is the SQL text substituted for the POSTCOMPILE token.
        """
        if parameter.literal_execute:
            # literal_execute: render values inline as literals rather
            # than generating new bound parameters
            return self._literal_execute_expanding_parameter_literal_binds(
                parameter, values
            )

        dialect = self.dialect
        typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect)

        # numeric paramstyles render through an intermediary template
        # during compilation
        if self._numeric_binds:
            bind_template = self.compilation_bindtemplate
        else:
            bind_template = self.bindtemplate

        if (
            self.dialect._bind_typing_render_casts
            and typ_dialect_impl.render_bind_cast
        ):

            def _render_bindtemplate(name):
                # wrap each placeholder in a dialect-specific cast
                return self.render_bind_cast(
                    parameter.type,
                    typ_dialect_impl,
                    bind_template % {"name": name},
                )

        else:

            def _render_bindtemplate(name):
                return bind_template % {"name": name}

        if not values:
            # empty IN list: no new parameters; render the dialect's
            # "empty set" expression instead
            to_update = []
            if typ_dialect_impl._is_tuple_type:
                replacement_expression = self.visit_empty_set_op_expr(
                    parameter.type.types, parameter.expand_op
                )
            else:
                replacement_expression = self.visit_empty_set_op_expr(
                    [parameter.type], parameter.expand_op
                )

        elif typ_dialect_impl._is_tuple_type or (
            typ_dialect_impl._isnull
            and isinstance(values[0], collections_abc.Sequence)
            and not isinstance(values[0], (str, bytes))
        ):
            # tuple-IN: one bind per tuple member, named "<name>_<i>_<j>"
            # NOTE(review): the flat-index lookup into to_update below
            # presumes all tuples in "values" have equal length — confirm
            assert not typ_dialect_impl._is_array
            to_update = [
                ("%s_%s_%s" % (name, i, j), value)
                for i, tuple_element in enumerate(values, 1)
                for j, value in enumerate(tuple_element, 1)
            ]

            replacement_expression = (
                "VALUES " if dialect.tuple_in_values else ""
            ) + ", ".join(
                "(%s)"
                % (
                    ", ".join(
                        _render_bindtemplate(
                            to_update[i * len(tuple_element) + j][0]
                        )
                        for j, value in enumerate(tuple_element)
                    )
                )
                for i, tuple_element in enumerate(values)
            )
        else:
            # scalar IN: one bind per value, named "<name>_<i>"
            to_update = [
                ("%s_%s" % (name, i), value)
                for i, value in enumerate(values, 1)
            ]
            replacement_expression = ", ".join(
                _render_bindtemplate(key) for key, value in to_update
            )

        return to_update, replacement_expression
3380
    def visit_binary(
        self,
        binary,
        override_operator=None,
        eager_grouping=False,
        from_linter=None,
        lateral_from_linter=None,
        **kw,
    ):
        """Render a binary expression, dispatching to an
        operator-specific ``visit_<op>_binary`` method when one exists,
        else rendering generically from the OPERATORS table.

        :param override_operator: operator used in place of
         ``binary.operator`` for dispatch and rendering.
        :param from_linter: collects linkage between the FROM objects of
         both operands, for cartesian product detection.
        :param lateral_from_linter: linter used when inside a LATERAL;
         edges additionally include the enclosing lateral element.
        """
        if from_linter and operators.is_comparison(binary.operator):
            # record that the FROM elements on either side of this
            # comparison are linked to each other
            if lateral_from_linter is not None:
                # within a lateral, also link both sides to the
                # enclosing lateral element itself
                enclosing_lateral = kw["enclosing_lateral"]
                lateral_from_linter.edges.update(
                    itertools.product(
                        _de_clone(
                            binary.left._from_objects + [enclosing_lateral]
                        ),
                        _de_clone(
                            binary.right._from_objects + [enclosing_lateral]
                        ),
                    )
                )
            else:
                from_linter.edges.update(
                    itertools.product(
                        _de_clone(binary.left._from_objects),
                        _de_clone(binary.right._from_objects),
                    )
                )

        # don't allow "? = ?" to render
        if (
            self.ansi_bind_rules
            and isinstance(binary.left, elements.BindParameter)
            and isinstance(binary.right, elements.BindParameter)
        ):
            kw["literal_execute"] = True

        operator_ = override_operator or binary.operator
        disp = self._get_operator_dispatch(operator_, "binary", None)
        if disp:
            return disp(binary, operator_, **kw)
        else:
            try:
                opstring = OPERATORS[operator_]
            except KeyError as err:
                raise exc.UnsupportedCompilationError(self, operator_) from err
            else:
                return self._generate_generic_binary(
                    binary,
                    opstring,
                    from_linter=from_linter,
                    lateral_from_linter=lateral_from_linter,
                    **kw,
                )
3436
3437 def visit_function_as_comparison_op_binary(self, element, operator, **kw):
3438 return self.process(element.sql_function, **kw)
3439
3440 def visit_mod_binary(self, binary, operator, **kw):
3441 if self.preparer._double_percents:
3442 return (
3443 self.process(binary.left, **kw)
3444 + " %% "
3445 + self.process(binary.right, **kw)
3446 )
3447 else:
3448 return (
3449 self.process(binary.left, **kw)
3450 + " % "
3451 + self.process(binary.right, **kw)
3452 )
3453
3454 def visit_custom_op_binary(self, element, operator, **kw):
3455 kw["eager_grouping"] = operator.eager_grouping
3456 return self._generate_generic_binary(
3457 element,
3458 " " + self.escape_literal_column(operator.opstring) + " ",
3459 **kw,
3460 )
3461
3462 def visit_custom_op_unary_operator(self, element, operator, **kw):
3463 return self._generate_generic_unary_operator(
3464 element, self.escape_literal_column(operator.opstring) + " ", **kw
3465 )
3466
3467 def visit_custom_op_unary_modifier(self, element, operator, **kw):
3468 return self._generate_generic_unary_modifier(
3469 element, " " + self.escape_literal_column(operator.opstring), **kw
3470 )
3471
3472 def _generate_generic_binary(
3473 self,
3474 binary: BinaryExpression[Any],
3475 opstring: str,
3476 eager_grouping: bool = False,
3477 **kw: Any,
3478 ) -> str:
3479 _in_operator_expression = kw.get("_in_operator_expression", False)
3480
3481 kw["_in_operator_expression"] = True
3482 kw["_binary_op"] = binary.operator
3483 text = (
3484 binary.left._compiler_dispatch(
3485 self, eager_grouping=eager_grouping, **kw
3486 )
3487 + opstring
3488 + binary.right._compiler_dispatch(
3489 self, eager_grouping=eager_grouping, **kw
3490 )
3491 )
3492
3493 if _in_operator_expression and eager_grouping:
3494 text = "(%s)" % text
3495 return text
3496
3497 def _generate_generic_unary_operator(self, unary, opstring, **kw):
3498 return opstring + unary.element._compiler_dispatch(self, **kw)
3499
3500 def _generate_generic_unary_modifier(self, unary, opstring, **kw):
3501 return unary.element._compiler_dispatch(self, **kw) + opstring
3502
3503 @util.memoized_property
3504 def _like_percent_literal(self):
3505 return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE)
3506
3507 def visit_ilike_case_insensitive_operand(self, element, **kw):
3508 return f"lower({element.element._compiler_dispatch(self, **kw)})"
3509
3510 def visit_contains_op_binary(self, binary, operator, **kw):
3511 binary = binary._clone()
3512 percent = self._like_percent_literal
3513 binary.right = percent.concat(binary.right).concat(percent)
3514 return self.visit_like_op_binary(binary, operator, **kw)
3515
3516 def visit_not_contains_op_binary(self, binary, operator, **kw):
3517 binary = binary._clone()
3518 percent = self._like_percent_literal
3519 binary.right = percent.concat(binary.right).concat(percent)
3520 return self.visit_not_like_op_binary(binary, operator, **kw)
3521
3522 def visit_icontains_op_binary(self, binary, operator, **kw):
3523 binary = binary._clone()
3524 percent = self._like_percent_literal
3525 binary.left = ilike_case_insensitive(binary.left)
3526 binary.right = percent.concat(
3527 ilike_case_insensitive(binary.right)
3528 ).concat(percent)
3529 return self.visit_ilike_op_binary(binary, operator, **kw)
3530
3531 def visit_not_icontains_op_binary(self, binary, operator, **kw):
3532 binary = binary._clone()
3533 percent = self._like_percent_literal
3534 binary.left = ilike_case_insensitive(binary.left)
3535 binary.right = percent.concat(
3536 ilike_case_insensitive(binary.right)
3537 ).concat(percent)
3538 return self.visit_not_ilike_op_binary(binary, operator, **kw)
3539
3540 def visit_startswith_op_binary(self, binary, operator, **kw):
3541 binary = binary._clone()
3542 percent = self._like_percent_literal
3543 binary.right = percent._rconcat(binary.right)
3544 return self.visit_like_op_binary(binary, operator, **kw)
3545
3546 def visit_not_startswith_op_binary(self, binary, operator, **kw):
3547 binary = binary._clone()
3548 percent = self._like_percent_literal
3549 binary.right = percent._rconcat(binary.right)
3550 return self.visit_not_like_op_binary(binary, operator, **kw)
3551
3552 def visit_istartswith_op_binary(self, binary, operator, **kw):
3553 binary = binary._clone()
3554 percent = self._like_percent_literal
3555 binary.left = ilike_case_insensitive(binary.left)
3556 binary.right = percent._rconcat(ilike_case_insensitive(binary.right))
3557 return self.visit_ilike_op_binary(binary, operator, **kw)
3558
3559 def visit_not_istartswith_op_binary(self, binary, operator, **kw):
3560 binary = binary._clone()
3561 percent = self._like_percent_literal
3562 binary.left = ilike_case_insensitive(binary.left)
3563 binary.right = percent._rconcat(ilike_case_insensitive(binary.right))
3564 return self.visit_not_ilike_op_binary(binary, operator, **kw)
3565
3566 def visit_endswith_op_binary(self, binary, operator, **kw):
3567 binary = binary._clone()
3568 percent = self._like_percent_literal
3569 binary.right = percent.concat(binary.right)
3570 return self.visit_like_op_binary(binary, operator, **kw)
3571
3572 def visit_not_endswith_op_binary(self, binary, operator, **kw):
3573 binary = binary._clone()
3574 percent = self._like_percent_literal
3575 binary.right = percent.concat(binary.right)
3576 return self.visit_not_like_op_binary(binary, operator, **kw)
3577
3578 def visit_iendswith_op_binary(self, binary, operator, **kw):
3579 binary = binary._clone()
3580 percent = self._like_percent_literal
3581 binary.left = ilike_case_insensitive(binary.left)
3582 binary.right = percent.concat(ilike_case_insensitive(binary.right))
3583 return self.visit_ilike_op_binary(binary, operator, **kw)
3584
3585 def visit_not_iendswith_op_binary(self, binary, operator, **kw):
3586 binary = binary._clone()
3587 percent = self._like_percent_literal
3588 binary.left = ilike_case_insensitive(binary.left)
3589 binary.right = percent.concat(ilike_case_insensitive(binary.right))
3590 return self.visit_not_ilike_op_binary(binary, operator, **kw)
3591
3592 def visit_like_op_binary(self, binary, operator, **kw):
3593 escape = binary.modifiers.get("escape", None)
3594
3595 return "%s LIKE %s" % (
3596 binary.left._compiler_dispatch(self, **kw),
3597 binary.right._compiler_dispatch(self, **kw),
3598 ) + (
3599 " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
3600 if escape is not None
3601 else ""
3602 )
3603
3604 def visit_not_like_op_binary(self, binary, operator, **kw):
3605 escape = binary.modifiers.get("escape", None)
3606 return "%s NOT LIKE %s" % (
3607 binary.left._compiler_dispatch(self, **kw),
3608 binary.right._compiler_dispatch(self, **kw),
3609 ) + (
3610 " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
3611 if escape is not None
3612 else ""
3613 )
3614
3615 def visit_ilike_op_binary(self, binary, operator, **kw):
3616 if operator is operators.ilike_op:
3617 binary = binary._clone()
3618 binary.left = ilike_case_insensitive(binary.left)
3619 binary.right = ilike_case_insensitive(binary.right)
3620 # else we assume ilower() has been applied
3621
3622 return self.visit_like_op_binary(binary, operator, **kw)
3623
3624 def visit_not_ilike_op_binary(self, binary, operator, **kw):
3625 if operator is operators.not_ilike_op:
3626 binary = binary._clone()
3627 binary.left = ilike_case_insensitive(binary.left)
3628 binary.right = ilike_case_insensitive(binary.right)
3629 # else we assume ilower() has been applied
3630
3631 return self.visit_not_like_op_binary(binary, operator, **kw)
3632
3633 def visit_between_op_binary(self, binary, operator, **kw):
3634 symmetric = binary.modifiers.get("symmetric", False)
3635 return self._generate_generic_binary(
3636 binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw
3637 )
3638
3639 def visit_not_between_op_binary(self, binary, operator, **kw):
3640 symmetric = binary.modifiers.get("symmetric", False)
3641 return self._generate_generic_binary(
3642 binary,
3643 " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ",
3644 **kw,
3645 )
3646
    def visit_regexp_match_op_binary(
        self, binary: BinaryExpression[Any], operator: Any, **kw: Any
    ) -> str:
        """Raise; regular expression matching is dialect-specific and
        must be provided by a dialect-level compiler subclass."""
        raise exc.CompileError(
            "%s dialect does not support regular expressions"
            % self.dialect.name
        )
3654
    def visit_not_regexp_match_op_binary(
        self, binary: BinaryExpression[Any], operator: Any, **kw: Any
    ) -> str:
        """Raise; negated regular expression matching is
        dialect-specific and must be provided by a subclass."""
        raise exc.CompileError(
            "%s dialect does not support regular expressions"
            % self.dialect.name
        )
3662
    def visit_regexp_replace_op_binary(
        self, binary: BinaryExpression[Any], operator: Any, **kw: Any
    ) -> str:
        """Raise; regular expression replacement is dialect-specific and
        must be provided by a dialect-level compiler subclass."""
        raise exc.CompileError(
            "%s dialect does not support regular expression replacements"
            % self.dialect.name
        )
3670
    def visit_bindparam(
        self,
        bindparam,
        within_columns_clause=False,
        literal_binds=False,
        skip_bind_expression=False,
        literal_execute=False,
        render_postcompile=False,
        **kwargs,
    ):
        """Render a :class:`.BindParameter`, registering it in
        ``self.binds`` under both its key and (possibly truncated) name.

        Depending on flags, renders a normal placeholder, a
        ``__[POSTCOMPILE_...]`` token (expanding / literal-execute
        parameters), or an inline literal value.

        :param within_columns_clause: True when rendering inside the
         columns clause; with ``ansi_bind_rules`` this forces literal
         execution.
        :param literal_binds: render the value inline as a literal.
        :param skip_bind_expression: internal flag set when re-entering
         this method for a type-level ``bind_expression()``.
        :param literal_execute: render as a POSTCOMPILE token to be
         resolved at execution time.
        :param render_postcompile: request immediate resolution of
         POSTCOMPILE tokens in the compiled string.
        """

        if not skip_bind_expression:
            impl = bindparam.type.dialect_impl(self.dialect)
            if impl._has_bind_expression:
                # the type wraps the bound parameter in a SQL expression;
                # render that expression with this same parameter inside
                # (skip_bind_expression=True prevents recursion)
                bind_expression = impl.bind_expression(bindparam)
                wrapped = self.process(
                    bind_expression,
                    skip_bind_expression=True,
                    within_columns_clause=within_columns_clause,
                    literal_binds=literal_binds and not bindparam.expanding,
                    literal_execute=literal_execute,
                    render_postcompile=render_postcompile,
                    **kwargs,
                )
                if bindparam.expanding:
                    # for postcompile w/ expanding, move the "wrapped" part
                    # of this into the inside

                    m = re.match(
                        r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
                    )
                    assert m, "unexpected format for expanding parameter"
                    wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
                        m.group(2),
                        m.group(1),
                        m.group(3),
                    )

                if literal_binds:
                    ret = self.render_literal_bindparam(
                        bindparam,
                        within_columns_clause=True,
                        bind_expression_template=wrapped,
                        **kwargs,
                    )
                    return "(%s)" % ret

                return wrapped

        if not literal_binds:
            literal_execute = (
                literal_execute
                or bindparam.literal_execute
                or (within_columns_clause and self.ansi_bind_rules)
            )
            post_compile = literal_execute or bindparam.expanding
        else:
            post_compile = False

        if literal_binds:
            ret = self.render_literal_bindparam(
                bindparam, within_columns_clause=True, **kwargs
            )
            if bindparam.expanding:
                ret = "(%s)" % ret
            return ret

        name = self._truncate_bindparam(bindparam)

        if name in self.binds:
            # a parameter of this name was already rendered; verify the
            # reuse is legal, else raise
            existing = self.binds[name]
            if existing is not bindparam:
                if (
                    (existing.unique or bindparam.unique)
                    and not existing.proxy_set.intersection(
                        bindparam.proxy_set
                    )
                    and not existing._cloned_set.intersection(
                        bindparam._cloned_set
                    )
                ):
                    raise exc.CompileError(
                        "Bind parameter '%s' conflicts with "
                        "unique bind parameter of the same name" % name
                    )
                elif existing.expanding != bindparam.expanding:
                    raise exc.CompileError(
                        "Can't reuse bound parameter name '%s' in both "
                        "'expanding' (e.g. within an IN expression) and "
                        "non-expanding contexts. If this parameter is to "
                        "receive a list/array value, set 'expanding=True' on "
                        "it for expressions that aren't IN, otherwise use "
                        "a different parameter name." % (name,)
                    )
                elif existing._is_crud or bindparam._is_crud:
                    if existing._is_crud and bindparam._is_crud:
                        # TODO: this condition is not well understood.
                        # see tests in test/sql/test_update.py
                        raise exc.CompileError(
                            "Encountered unsupported case when compiling an "
                            "INSERT or UPDATE statement. If this is a "
                            "multi-table "
                            "UPDATE statement, please provide string-named "
                            "arguments to the "
                            "values() method with distinct names; support for "
                            "multi-table UPDATE statements that "
                            "target multiple tables for UPDATE is very "
                            "limited",
                        )
                    else:
                        raise exc.CompileError(
                            f"bindparam() name '{bindparam.key}' is reserved "
                            "for automatic usage in the VALUES or SET "
                            "clause of this "
                            "insert/update statement. Please use a "
                            "name other than column name when using "
                            "bindparam() "
                            "with insert() or update() (for example, "
                            f"'b_{bindparam.key}')."
                        )

        self.binds[bindparam.key] = self.binds[name] = bindparam

        # if we are given a cache key that we're going to match against,
        # relate the bindparam here to one that is most likely present
        # in the "extracted params" portion of the cache key. this is used
        # to set up a positional mapping that is used to determine the
        # correct parameters for a subsequent use of this compiled with
        # a different set of parameter values. here, we accommodate for
        # parameters that may have been cloned both before and after the cache
        # key was been generated.
        ckbm_tuple = self._cache_key_bind_match

        if ckbm_tuple:
            ckbm, cksm = ckbm_tuple
            for bp in bindparam._cloned_set:
                if bp.key in cksm:
                    cb = cksm[bp.key]
                    ckbm[cb].append(bindparam)

        if bindparam.isoutparam:
            self.has_out_parameters = True

        if post_compile:
            if render_postcompile:
                self._render_postcompile = True

            # literal-execute and post-compile parameters are tracked
            # separately for later resolution
            if literal_execute:
                self.literal_execute_params |= {bindparam}
            else:
                self.post_compile_params |= {bindparam}

        ret = self.bindparam_string(
            name,
            post_compile=post_compile,
            expanding=bindparam.expanding,
            bindparam_type=bindparam.type,
            **kwargs,
        )

        if bindparam.expanding:
            # expanding parameters are always parenthesized
            ret = "(%s)" % ret

        return ret
3835
    def render_bind_cast(self, type_, dbapi_type, sqltext):
        """Render a cast around a bound parameter's placeholder text;
        implemented by dialects whose type implementations set
        ``render_bind_cast`` / ``render_literal_cast``."""
        raise NotImplementedError()
3838
    def render_literal_bindparam(
        self,
        bindparam,
        render_literal_value=NO_ARG,
        bind_expression_template=None,
        **kw,
    ):
        """Render a bound parameter's value inline as a SQL literal.

        :param render_literal_value: explicit value to render,
         bypassing the parameter's own value/callable.
        :param bind_expression_template: for expanding parameters, the
         post-compile template derived from a type-level
         ``bind_expression()``.
        """
        if render_literal_value is not NO_ARG:
            value = render_literal_value
        else:
            if bindparam.value is None and bindparam.callable is None:
                # no value present: render SQL NULL, warning when the
                # comparison operator is anything other than IS / IS NOT
                op = kw.get("_binary_op", None)
                if op and op not in (operators.is_, operators.is_not):
                    util.warn_limited(
                        "Bound parameter '%s' rendering literal NULL in a SQL "
                        "expression; comparisons to NULL should not use "
                        "operators outside of 'is' or 'is not'",
                        (bindparam.key,),
                    )
                return self.process(sqltypes.NULLTYPE, **kw)
            value = bindparam.effective_value

        if bindparam.expanding:
            # expanding (IN) parameter: render the whole list inline
            leep = self._literal_execute_expanding_parameter_literal_binds
            to_update, replacement_expr = leep(
                bindparam,
                value,
                bind_expression_template=bind_expression_template,
            )
            return replacement_expr
        else:
            return self.render_literal_value(value, bindparam.type)
3871
    def render_literal_value(
        self, value: Any, type_: sqltypes.TypeEngine[Any]
    ) -> str:
        """Render the value of a bind parameter as a quoted literal.

        This is used for statement sections that do not accept bind parameters
        on the target driver/database.

        This should be implemented by subclasses using the quoting services
        of the DBAPI.

        :raises exc.CompileError: if the type provides no literal
         processor, or the processor fails on the given value.
        """

        if value is None and not type_.should_evaluate_none:
            # issue #10535 - handle NULL in the compiler without placing
            # this onto each type, except for "evaluate None" types
            # (e.g. JSON)
            return self.process(elements.Null._instance())

        processor = type_._cached_literal_processor(self.dialect)
        if processor:
            try:
                return processor(value)
            except Exception as e:
                # chain the original failure for diagnostics
                raise exc.CompileError(
                    f"Could not render literal value "
                    f'"{sql_util._repr_single_value(value)}" '
                    f"with datatype "
                    f"{type_}; see parent stack trace for "
                    "more detail."
                ) from e

        else:
            raise exc.CompileError(
                f"No literal value renderer is available for literal value "
                f'"{sql_util._repr_single_value(value)}" '
                f"with datatype {type_}"
            )
3910
3911 def _truncate_bindparam(self, bindparam):
3912 if bindparam in self.bind_names:
3913 return self.bind_names[bindparam]
3914
3915 bind_name = bindparam.key
3916 if isinstance(bind_name, elements._truncated_label):
3917 bind_name = self._truncated_identifier("bindparam", bind_name)
3918
3919 # add to bind_names for translation
3920 self.bind_names[bindparam] = bind_name
3921
3922 return bind_name
3923
3924 def _truncated_identifier(
3925 self, ident_class: str, name: _truncated_label
3926 ) -> str:
3927 if (ident_class, name) in self.truncated_names:
3928 return self.truncated_names[(ident_class, name)]
3929
3930 anonname = name.apply_map(self.anon_map)
3931
3932 if len(anonname) > self.label_length - 6:
3933 counter = self._truncated_counters.get(ident_class, 1)
3934 truncname = (
3935 anonname[0 : max(self.label_length - 6, 0)]
3936 + "_"
3937 + hex(counter)[2:]
3938 )
3939 self._truncated_counters[ident_class] = counter + 1
3940 else:
3941 truncname = anonname
3942 self.truncated_names[(ident_class, name)] = truncname
3943 return truncname
3944
3945 def _anonymize(self, name: str) -> str:
3946 return name % self.anon_map
3947
    def bindparam_string(
        self,
        name: str,
        post_compile: bool = False,
        expanding: bool = False,
        escaped_from: Optional[str] = None,
        bindparam_type: Optional[TypeEngine[Any]] = None,
        accumulate_bind_names: Optional[Set[str]] = None,
        visited_bindparam: Optional[List[str]] = None,
        **kw: Any,
    ) -> str:
        """Produce the placeholder text for a named parameter.

        Escapes characters the paramstyle can't accept, emits a
        ``__[POSTCOMPILE_...]`` token for post-compile parameters, and
        applies render casts where the dialect/type requires them.
        """
        # TODO: accumulate_bind_names is passed by crud.py to gather
        # names on a per-value basis, visited_bindparam is passed by
        # visit_insert() to collect all parameters in the statement.
        # see if this gathering can be simplified somehow
        if accumulate_bind_names is not None:
            accumulate_bind_names.add(name)
        if visited_bindparam is not None:
            visited_bindparam.append(name)

        if not escaped_from:
            if self._bind_translate_re.search(name):
                # not quite the translate use case as we want to
                # also get a quick boolean if we even found
                # unusual characters in the name
                new_name = self._bind_translate_re.sub(
                    lambda m: self._bind_translate_chars[m.group(0)],
                    name,
                )
                escaped_from = name
                name = new_name

        if escaped_from:
            # record the original -> escaped mapping so parameter
            # dictionaries can be translated at execution time
            self.escaped_bind_names = self.escaped_bind_names.union(
                {escaped_from: name}
            )
        if post_compile:
            ret = "__[POSTCOMPILE_%s]" % name
            if expanding:
                # for expanding, bound parameters or literal values will be
                # rendered per item
                return ret

            # otherwise, for non-expanding "literal execute", apply
            # bind casts as determined by the datatype
            if bindparam_type is not None:
                type_impl = bindparam_type._unwrapped_dialect_impl(
                    self.dialect
                )
                if type_impl.render_literal_cast:
                    ret = self.render_bind_cast(bindparam_type, type_impl, ret)
            return ret
        elif self.state is CompilerState.COMPILING:
            # during compilation, numeric paramstyles go through an
            # intermediary template
            ret = self.compilation_bindtemplate % {"name": name}
        else:
            ret = self.bindtemplate % {"name": name}

        if (
            bindparam_type is not None
            and self.dialect._bind_typing_render_casts
        ):
            type_impl = bindparam_type._unwrapped_dialect_impl(self.dialect)
            if type_impl.render_bind_cast:
                ret = self.render_bind_cast(bindparam_type, type_impl, ret)

        return ret
4014
4015 def _dispatch_independent_ctes(self, stmt, kw):
4016 local_kw = kw.copy()
4017 local_kw.pop("cte_opts", None)
4018 for cte, opt in zip(
4019 stmt._independent_ctes, stmt._independent_ctes_opts
4020 ):
4021 cte._compiler_dispatch(self, cte_opts=opt, **local_kw)
4022
    def visit_cte(
        self,
        cte: CTE,
        asfrom: bool = False,
        ashint: bool = False,
        fromhints: Optional[_FromHintsType] = None,
        visiting_cte: Optional[CTE] = None,
        from_linter: Optional[FromLinter] = None,
        cte_opts: selectable._CTEOpts = selectable._CTEOpts(False),
        **kwargs: Any,
    ) -> Optional[str]:
        """Render a :class:`.CTE`, accumulating its WITH-clause text.

        The body text is stored in ``self.ctes``; the dictionaries
        ``ctes_by_level_name`` / ``level_name_by_cte`` track names per
        nesting level so duplicates can be detected or merged.  The
        returned string, when not ``None``, is the name (plus optional
        alias suffix) to render at the point of use.

        :param asfrom: True when the CTE is referenced in a FROM clause.
        :param visiting_cte: the CTE currently being compiled, used to
         detect references to an enclosing same-named CTE.
        :param cte_opts: per-use nesting options for this CTE.
        """
        self_ctes = self._init_cte_state()
        assert self_ctes is self.ctes

        kwargs["visiting_cte"] = cte

        cte_name = cte.name

        if isinstance(cte_name, elements._truncated_label):
            cte_name = self._truncated_identifier("alias", cte_name)

        is_new_cte = True
        embedded_in_current_named_cte = False

        _reference_cte = cte._get_reference_cte()

        nesting = cte.nesting or cte_opts.nesting

        # check for CTE already encountered
        if _reference_cte in self.level_name_by_cte:
            cte_level, _, existing_cte_opts = self.level_name_by_cte[
                _reference_cte
            ]
            assert _ == cte_name

            cte_level_name = (cte_level, cte_name)
            existing_cte = self.ctes_by_level_name[cte_level_name]

            # check if we are receiving it here with a specific
            # "nest_here" location; if so, move it to this location

            if cte_opts.nesting:
                if existing_cte_opts.nesting:
                    raise exc.CompileError(
                        "CTE is stated as 'nest_here' in "
                        "more than one location"
                    )

                old_level_name = (cte_level, cte_name)
                cte_level = len(self.stack) if nesting else 1
                cte_level_name = new_level_name = (cte_level, cte_name)

                # re-key the existing entry under the new level
                del self.ctes_by_level_name[old_level_name]
                self.ctes_by_level_name[new_level_name] = existing_cte
                self.level_name_by_cte[_reference_cte] = new_level_name + (
                    cte_opts,
                )

        else:
            cte_level = len(self.stack) if nesting else 1
            cte_level_name = (cte_level, cte_name)

            if cte_level_name in self.ctes_by_level_name:
                existing_cte = self.ctes_by_level_name[cte_level_name]
            else:
                existing_cte = None

        if existing_cte is not None:
            embedded_in_current_named_cte = visiting_cte is existing_cte

            # we've generated a same-named CTE that we are enclosed in,
            # or this is the same CTE. just return the name.
            if cte is existing_cte._restates or cte is existing_cte:
                is_new_cte = False
            elif existing_cte is cte._restates:
                # we've generated a same-named CTE that is
                # enclosed in us - we take precedence, so
                # discard the text for the "inner".
                del self_ctes[existing_cte]

                existing_cte_reference_cte = existing_cte._get_reference_cte()

                assert existing_cte_reference_cte is _reference_cte
                assert existing_cte_reference_cte is existing_cte

                del self.level_name_by_cte[existing_cte_reference_cte]
            else:
                if (
                    # if the two CTEs have the same hash, which we expect
                    # here means that one/both is an annotated of the other
                    (hash(cte) == hash(existing_cte))
                    # or...
                    or (
                        (
                            # if they are clones, i.e. they came from the ORM
                            # or some other visit method
                            cte._is_clone_of is not None
                            or existing_cte._is_clone_of is not None
                        )
                        # and are deep-copy identical
                        and cte.compare(existing_cte)
                    )
                ):
                    # then consider these two CTEs the same
                    is_new_cte = False
                else:
                    # otherwise these are two CTEs that either will render
                    # differently, or were indicated separately by the user,
                    # with the same name
                    raise exc.CompileError(
                        "Multiple, unrelated CTEs found with "
                        "the same name: %r" % cte_name
                    )

        if not asfrom and not is_new_cte:
            return None

        if cte._cte_alias is not None:
            # this CTE is an alias of another; resolve the aliased CTE's
            # (possibly truncated) name
            pre_alias_cte = cte._cte_alias
            cte_pre_alias_name = cte._cte_alias.name
            if isinstance(cte_pre_alias_name, elements._truncated_label):
                cte_pre_alias_name = self._truncated_identifier(
                    "alias", cte_pre_alias_name
                )
        else:
            pre_alias_cte = cte
            cte_pre_alias_name = None

        if is_new_cte:
            self.ctes_by_level_name[cte_level_name] = cte
            self.level_name_by_cte[_reference_cte] = cte_level_name + (
                cte_opts,
            )

            if pre_alias_cte not in self.ctes:
                self.visit_cte(pre_alias_cte, **kwargs)

            if not cte_pre_alias_name and cte not in self_ctes:
                if cte.recursive:
                    self.ctes_recursive = True
                text = self.preparer.format_alias(cte, cte_name)
                if cte.recursive:
                    # recursive CTEs render an explicit column list
                    col_source = cte.element

                    # TODO: can we get at the .columns_plus_names collection
                    # that is already (or will be?) generated for the SELECT
                    # rather than calling twice?
                    recur_cols = [
                        # TODO: proxy_name is not technically safe,
                        # see test_cte->
                        # test_with_recursive_no_name_currently_buggy. not
                        # clear what should be done with such a case
                        fallback_label_name or proxy_name
                        for (
                            _,
                            proxy_name,
                            fallback_label_name,
                            c,
                            repeated,
                        ) in (col_source._generate_columns_plus_names(True))
                        if not repeated
                    ]

                    text += "(%s)" % (
                        ", ".join(
                            self.preparer.format_label_name(
                                ident, anon_map=self.anon_map
                            )
                            for ident in recur_cols
                        )
                    )

                assert kwargs.get("subquery", False) is False

                if not self.stack:
                    # toplevel, this is a stringify of the
                    # cte directly. just compile the inner
                    # the way alias() does.
                    return cte.element._compiler_dispatch(
                        self, asfrom=asfrom, **kwargs
                    )
                else:
                    prefixes = self._generate_prefixes(
                        cte, cte._prefixes, **kwargs
                    )
                    inner = cte.element._compiler_dispatch(
                        self, asfrom=True, **kwargs
                    )

                text += " AS %s\n(%s)" % (prefixes, inner)

                if cte._suffixes:
                    text += " " + self._generate_prefixes(
                        cte, cte._suffixes, **kwargs
                    )

                self_ctes[cte] = text

        if asfrom:
            if from_linter:
                from_linter.froms[cte._de_clone()] = cte_name

            if not is_new_cte and embedded_in_current_named_cte:
                return self.preparer.format_alias(cte, cte_name)

            if cte_pre_alias_name:
                text = self.preparer.format_alias(cte, cte_pre_alias_name)
                if self.preparer._requires_quotes(cte_name):
                    cte_name = self.preparer.quote(cte_name)
                text += self.get_render_as_alias_suffix(cte_name)
                return text
            else:
                return self.preparer.format_alias(cte, cte_name)

        return None
4238
4239 def visit_table_valued_alias(self, element, **kw):
4240 if element.joins_implicitly:
4241 kw["from_linter"] = None
4242 if element._is_lateral:
4243 return self.visit_lateral(element, **kw)
4244 else:
4245 return self.visit_alias(element, **kw)
4246
4247 def visit_table_valued_column(self, element, **kw):
4248 return self.visit_column(element, **kw)
4249
    def visit_alias(
        self,
        alias,
        asfrom=False,
        ashint=False,
        iscrud=False,
        fromhints=None,
        subquery=False,
        lateral=False,
        enclosing_alias=None,
        from_linter=None,
        **kwargs,
    ):
        """Render an alias construct: as ``<element> AS <name>`` within a
        FROM clause, as just the alias name when ``ashint`` is set, or as
        only the inner element when outside of a FROM context.

        :param asfrom: the alias is being rendered in the FROM clause;
         triggers the ``AS <name>`` suffix.
        :param ashint: render only the formatted alias name, for use
         inside dialect hint text.
        :param subquery: parenthesize the inner element.
        :param lateral: the alias is the subject of a LATERAL clause.
        :param enclosing_alias: the alias object immediately enclosing
         this element in the compile, used to detect double-wrapping.
        """
        if lateral:
            if "enclosing_lateral" not in kwargs:
                # if lateral is set and enclosing_lateral is not
                # present, we assume we are being called directly
                # from visit_lateral() and we need to set enclosing_lateral.
                assert alias._is_lateral
                kwargs["enclosing_lateral"] = alias

            # for lateral objects, we track a second from_linter that is...
            # lateral! to the level above us.
            if (
                from_linter
                and "lateral_from_linter" not in kwargs
                and "enclosing_lateral" in kwargs
            ):
                kwargs["lateral_from_linter"] = from_linter

        if enclosing_alias is not None and enclosing_alias.element is alias:
            # the enclosing alias directly wraps this alias; render only
            # the inner element here, without an additional alias layer
            inner = alias.element._compiler_dispatch(
                self,
                asfrom=asfrom,
                ashint=ashint,
                iscrud=iscrud,
                fromhints=fromhints,
                lateral=lateral,
                enclosing_alias=alias,
                **kwargs,
            )
            if subquery and (asfrom or lateral):
                inner = "(%s)" % (inner,)
            return inner
        else:
            kwargs["enclosing_alias"] = alias

        if asfrom or ashint:
            if isinstance(alias.name, elements._truncated_label):
                # deferred-truncation name; resolve against dialect limits
                alias_name = self._truncated_identifier("alias", alias.name)
            else:
                alias_name = alias.name

            if ashint:
                return self.preparer.format_alias(alias, alias_name)
            elif asfrom:
                if from_linter:
                    from_linter.froms[alias._de_clone()] = alias_name

                inner = alias.element._compiler_dispatch(
                    self, asfrom=True, lateral=lateral, **kwargs
                )
                if subquery:
                    inner = "(%s)" % (inner,)

                ret = inner + self.get_render_as_alias_suffix(
                    self.preparer.format_alias(alias, alias_name)
                )

                if alias._supports_derived_columns and alias._render_derived:
                    # render the derived column list, e.g.
                    # "name(col1, col2, ...)", optionally with types
                    ret += "(%s)" % (
                        ", ".join(
                            "%s%s"
                            % (
                                self.preparer.quote(col.name),
                                (
                                    " %s"
                                    % self.dialect.type_compiler_instance.process(
                                        col.type, **kwargs
                                    )
                                    if alias._render_derived_w_types
                                    else ""
                                ),
                            )
                            for col in alias.c
                        )
                    )

                if fromhints and alias in fromhints:
                    ret = self.format_from_hint_text(
                        ret, alias, fromhints[alias], iscrud
                    )

                return ret
        else:
            # note we cancel the "subquery" flag here as well
            return alias.element._compiler_dispatch(
                self, lateral=lateral, **kwargs
            )
4349
4350 def visit_subquery(self, subquery, **kw):
4351 kw["subquery"] = True
4352 return self.visit_alias(subquery, **kw)
4353
4354 def visit_lateral(self, lateral_, **kw):
4355 kw["lateral"] = True
4356 return "LATERAL %s" % self.visit_alias(lateral_, **kw)
4357
4358 def visit_tablesample(self, tablesample, asfrom=False, **kw):
4359 text = "%s TABLESAMPLE %s" % (
4360 self.visit_alias(tablesample, asfrom=True, **kw),
4361 tablesample._get_method()._compiler_dispatch(self, **kw),
4362 )
4363
4364 if tablesample.seed is not None:
4365 text += " REPEATABLE (%s)" % (
4366 tablesample.seed._compiler_dispatch(self, **kw)
4367 )
4368
4369 return text
4370
4371 def _render_values(self, element, **kw):
4372 kw.setdefault("literal_binds", element.literal_binds)
4373 tuples = ", ".join(
4374 self.process(
4375 elements.Tuple(
4376 types=element._column_types, *elem
4377 ).self_group(),
4378 **kw,
4379 )
4380 for chunk in element._data
4381 for elem in chunk
4382 )
4383 return f"VALUES {tuples}"
4384
    def visit_values(self, element, asfrom=False, from_linter=None, **kw):
        """Render a ``VALUES`` construct, adding the alias name and
        derived column list when rendered as a FROM element.
        """
        v = self._render_values(element, **kw)

        if element._unnamed:
            name = None
        elif isinstance(element.name, elements._truncated_label):
            # deferred-truncation name; resolve to a dialect-legal length
            name = self._truncated_identifier("values", element.name)
        else:
            name = element.name

        if element._is_lateral:
            lateral = "LATERAL "
        else:
            lateral = ""

        if asfrom:
            if from_linter:
                from_linter.froms[element._de_clone()] = (
                    name if name is not None else "(unnamed VALUES element)"
                )

            if name:
                # render the derived column list without table
                # qualification on the column expressions
                kw["include_table"] = False
                v = "%s(%s)%s (%s)" % (
                    lateral,
                    v,
                    self.get_render_as_alias_suffix(self.preparer.quote(name)),
                    (
                        ", ".join(
                            c._compiler_dispatch(self, **kw)
                            for c in element.columns
                        )
                    ),
                )
            else:
                v = "%s(%s)" % (lateral, v)
        return v
4422
4423 def visit_scalar_values(self, element, **kw):
4424 return f"({self._render_values(element, **kw)})"
4425
4426 def get_render_as_alias_suffix(self, alias_name_text):
4427 return " AS " + alias_name_text
4428
    def _add_to_result_map(
        self,
        keyname: Optional[str],
        name: str,
        objects: Tuple[Any, ...],
        type_: TypeEngine[Any],
    ) -> None:
        """Add an entry to the result map, linking a rendered column in
        the compiled statement to result-fetching metadata.

        :param keyname: key under which the result column is matched;
         ``None`` or ``"*"`` indicates a textual/wildcard column with no
         deterministic position.
        :param name: rendered name of the column.
        :param objects: non-empty tuple of objects that should match to
         this column when looking up results.
        :param type_: the SQL type associated with the column.
        """

        # note objects must be non-empty for cursor.py to handle the
        # collection properly
        assert objects

        if keyname is None or keyname == "*":
            # textual / wildcard column; positional ordering of result
            # columns can no longer be relied upon
            self._ordered_columns = False
            self._ad_hoc_textual = True
        if type_._is_tuple_type:
            raise exc.CompileError(
                "Most backends don't support SELECTing "
                "from a tuple() object. If this is an ORM query, "
                "consider using the Bundle object."
            )
        self._result_columns.append(
            ResultColumnsEntry(keyname, name, objects, type_)
        )
4453
4454 def _label_returning_column(
4455 self, stmt, column, populate_result_map, column_clause_args=None, **kw
4456 ):
4457 """Render a column with necessary labels inside of a RETURNING clause.
4458
4459 This method is provided for individual dialects in place of calling
4460 the _label_select_column method directly, so that the two use cases
4461 of RETURNING vs. SELECT can be disambiguated going forward.
4462
4463 .. versionadded:: 1.4.21
4464
4465 """
4466 return self._label_select_column(
4467 None,
4468 column,
4469 populate_result_map,
4470 False,
4471 {} if column_clause_args is None else column_clause_args,
4472 **kw,
4473 )
4474
    def _label_select_column(
        self,
        select,
        column,
        populate_result_map,
        asfrom,
        column_clause_args,
        name=None,
        proxy_name=None,
        fallback_label_name=None,
        within_columns_clause=True,
        column_is_repeated=False,
        need_column_expressions=False,
        include_table=True,
    ):
        """produce labeled columns present in a select().

        Renders a single column expression for the columns clause of a
        SELECT (or a RETURNING clause, via _label_returning_column),
        applying an explicit or heuristically-determined label and
        optionally registering the column in the result map.

        :param select: the enclosing Select, or ``None`` for RETURNING.
        :param column: the column expression to render.
        :param populate_result_map: register the rendered column in the
         result map for result-row matching.
        :param asfrom: the enclosing SELECT is rendered as a FROM
         element (subquery); affects labeling heuristics.
        :param name: explicit label name determined by the caller, if
         any; requires ``proxy_name`` as well.
        :param fallback_label_name: label to use when heuristics decide
         a label is needed but no explicit name was given.
        :param column_is_repeated: the SELECT indicated this column is a
         duplicate; suppresses result-map target registration.
        """
        impl = column.type.dialect_impl(self.dialect)

        if impl._has_column_expression and (
            need_column_expressions or populate_result_map
        ):
            # the type provides a column_expression() wrapper, e.g. for
            # type-level SQL transformation of the fetched value
            col_expr = impl.column_expression(column)
        else:
            col_expr = column

        if populate_result_map:
            # pass an "add_to_result_map" callable into the compilation
            # of embedded columns. this collects information about the
            # column as it will be fetched in the result and is coordinated
            # with cursor.description when the query is executed.
            add_to_result_map = self._add_to_result_map

            # if the SELECT statement told us this column is a repeat,
            # wrap the callable with one that prevents the addition of the
            # targets
            if column_is_repeated:
                _add_to_result_map = add_to_result_map

                def add_to_result_map(keyname, name, objects, type_):
                    _add_to_result_map(keyname, name, (keyname,), type_)

            # if we redefined col_expr for type expressions, wrap the
            # callable with one that adds the original column to the targets
            elif col_expr is not column:
                _add_to_result_map = add_to_result_map

                def add_to_result_map(keyname, name, objects, type_):
                    _add_to_result_map(
                        keyname, name, (column,) + objects, type_
                    )

        else:
            add_to_result_map = None

        # this method is used by some of the dialects for RETURNING,
        # which has different inputs.  _label_returning_column was added
        # as the better target for this now however for 1.4 we will keep
        # _label_select_column directly compatible with this use case.
        # these assertions right now set up the current expected inputs
        assert within_columns_clause, (
            "_label_select_column is only relevant within "
            "the columns clause of a SELECT or RETURNING"
        )
        if isinstance(column, elements.Label):
            if col_expr is not column:
                result_expr = _CompileLabel(
                    col_expr, column.name, alt_names=(column.element,)
                )
            else:
                result_expr = col_expr

        elif name:
            # here, _columns_plus_names has determined there's an explicit
            # label name we need to use.  this is the default for
            # tablenames_plus_columnnames as well as when columns are being
            # deduplicated on name

            assert (
                proxy_name is not None
            ), "proxy_name is required if 'name' is passed"

            result_expr = _CompileLabel(
                col_expr,
                name,
                alt_names=(
                    proxy_name,
                    # this is a hack to allow legacy result column lookups
                    # to work as they did before; this goes away in 2.0.
                    # TODO: this only seems to be tested indirectly
                    # via test/orm/test_deprecations.py.   should be a
                    # resultset test for this
                    column._tq_label,
                ),
            )
        else:
            # determine here whether this column should be rendered in
            # a labelled context or not, as we were given no required label
            # name from the caller. Here we apply heuristics based on the kind
            # of SQL expression involved.

            if col_expr is not column:
                # type-specific expression wrapping the given column,
                # so we render a label
                render_with_label = True
            elif isinstance(column, elements.ColumnClause):
                # table-bound column, we render its name as a label if we are
                # inside of a subquery only
                render_with_label = (
                    asfrom
                    and not column.is_literal
                    and column.table is not None
                )
            elif isinstance(column, elements.TextClause):
                render_with_label = False
            elif isinstance(column, elements.UnaryExpression):
                render_with_label = column.wraps_column_expression or asfrom
            elif (
                # general class of expressions that don't have a SQL-column
                # addressible name.  includes scalar selects, bind parameters,
                # SQL functions, others
                not isinstance(column, elements.NamedColumn)
                # deeper check that indicates there's no natural "name" to
                # this element, which accommodates for custom SQL constructs
                # that might have a ".name" attribute (but aren't SQL
                # functions) but are not implementing this more recently added
                # base class.  in theory the "NamedColumn" check should be
                # enough, however here we seek to maintain legacy behaviors
                # as well.
                and column._non_anon_label is None
            ):
                render_with_label = True
            else:
                render_with_label = False

            if render_with_label:
                if not fallback_label_name:
                    # used by the RETURNING case right now.  we generate it
                    # here as 3rd party dialects may be referring to
                    # _label_select_column method directly instead of the
                    # just-added _label_returning_column method
                    assert not column_is_repeated
                    fallback_label_name = column._anon_name_label

                fallback_label_name = (
                    elements._truncated_label(fallback_label_name)
                    if not isinstance(
                        fallback_label_name, elements._truncated_label
                    )
                    else fallback_label_name
                )

                result_expr = _CompileLabel(
                    col_expr, fallback_label_name, alt_names=(proxy_name,)
                )
            else:
                result_expr = col_expr

        column_clause_args.update(
            within_columns_clause=within_columns_clause,
            add_to_result_map=add_to_result_map,
            include_table=include_table,
        )
        return result_expr._compiler_dispatch(self, **column_clause_args)
4638
4639 def format_from_hint_text(self, sqltext, table, hint, iscrud):
4640 hinttext = self.get_from_hint_text(table, hint)
4641 if hinttext:
4642 sqltext += " " + hinttext
4643 return sqltext
4644
4645 def get_select_hint_text(self, byfroms):
4646 return None
4647
4648 def get_from_hint_text(
4649 self, table: FromClause, text: Optional[str]
4650 ) -> Optional[str]:
4651 return None
4652
4653 def get_crud_hint_text(self, table, text):
4654 return None
4655
4656 def get_statement_hint_text(self, hint_texts):
4657 return " ".join(hint_texts)
4658
    # stack entry used when the compiler is operating at the topmost
    # level, i.e. ``self.stack`` is empty
    _default_stack_entry: _CompilerStackEntry

    if not typing.TYPE_CHECKING:
        _default_stack_entry = util.immutabledict(
            [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
        )
4665
    def _display_froms_for_select(
        self, select_stmt, asfrom, lateral=False, **kw
    ):
        """Return the FROM elements a SELECT will display, applying the
        same correlation rules used during compilation.
        """
        # utility method to help external dialects
        # get the correct from list for a select.
        # specifically the oracle dialect needs this feature
        # right now.
        toplevel = not self.stack
        entry = self._default_stack_entry if toplevel else self.stack[-1]

        compile_state = select_stmt._compile_state_factory(select_stmt, self)

        correlate_froms = entry["correlate_froms"]
        asfrom_froms = entry["asfrom_froms"]

        if asfrom and not lateral:
            # as a FROM element (subquery), explicitly correlate only
            # against enclosing froms not introduced at this same level
            froms = compile_state._get_display_froms(
                explicit_correlate_froms=correlate_froms.difference(
                    asfrom_froms
                ),
                implicit_correlate_froms=(),
            )
        else:
            froms = compile_state._get_display_froms(
                explicit_correlate_froms=correlate_froms,
                implicit_correlate_froms=asfrom_froms,
            )
        return froms
4694
    translate_select_structure: Any = None
    """If not ``None``, should be a callable which accepts ``(select_stmt,
    **kw)`` and returns a select object.  This is used for structural
    changes to the statement, mostly to accommodate backend-specific
    LIMIT/OFFSET schemes.

    """
4701
    def visit_select(
        self,
        select_stmt,
        asfrom=False,
        insert_into=False,
        fromhints=None,
        compound_index=None,
        select_wraps_for=None,
        lateral=False,
        from_linter=None,
        **kwargs,
    ):
        """Compile a SELECT statement to a SQL string.

        :param asfrom: the SELECT is being rendered as a FROM element,
         i.e. a subquery.
        :param insert_into: the SELECT is embedded in an
         INSERT ... SELECT.
        :param compound_index: position of this SELECT within an
         enclosing compound (e.g. UNION), or ``None``.
        :param select_wraps_for: legacy parameter; must be ``None``, as
         asserted immediately below.
        :param lateral: the SELECT is the element of a LATERAL clause.
        """
        assert select_wraps_for is None, (
            "SQLAlchemy 1.4 requires use of "
            "the translate_select_structure hook for structural "
            "translations of SELECT objects"
        )

        # initial setup of SELECT. the compile_state_factory may now
        # be creating a totally different SELECT from the one that was
        # passed in. for ORM use this will convert from an ORM-state
        # SELECT to a regular "Core" SELECT. other composed operations
        # such as computation of joins will be performed.

        kwargs["within_columns_clause"] = False

        compile_state = select_stmt._compile_state_factory(
            select_stmt, self, **kwargs
        )
        kwargs["ambiguous_table_name_map"] = (
            compile_state._ambiguous_table_name_map
        )

        select_stmt = compile_state.statement

        toplevel = not self.stack

        if toplevel and not self.compile_state:
            self.compile_state = compile_state

        is_embedded_select = compound_index is not None or insert_into

        # translate step for Oracle, SQL Server which often need to
        # restructure the SELECT to allow for LIMIT/OFFSET and possibly
        # other conditions
        if self.translate_select_structure:
            new_select_stmt = self.translate_select_structure(
                select_stmt, asfrom=asfrom, **kwargs
            )

            # if SELECT was restructured, maintain a link to the originals
            # and assemble a new compile state
            if new_select_stmt is not select_stmt:
                compile_state_wraps_for = compile_state
                select_wraps_for = select_stmt
                select_stmt = new_select_stmt

                compile_state = select_stmt._compile_state_factory(
                    select_stmt, self, **kwargs
                )
                select_stmt = compile_state.statement

        entry = self._default_stack_entry if toplevel else self.stack[-1]

        populate_result_map = need_column_expressions = (
            toplevel
            or entry.get("need_result_map_for_compound", False)
            or entry.get("need_result_map_for_nested", False)
        )

        # indicates there is a CompoundSelect in play and we are not the
        # first select
        if compound_index:
            populate_result_map = False

        # this was first proposed as part of #3372; however, it is not
        # reached in current tests and could possibly be an assertion
        # instead.
        if not populate_result_map and "add_to_result_map" in kwargs:
            del kwargs["add_to_result_map"]

        froms = self._setup_select_stack(
            select_stmt, compile_state, entry, asfrom, lateral, compound_index
        )

        column_clause_args = kwargs.copy()
        column_clause_args.update(
            {"within_label_clause": False, "within_columns_clause": False}
        )

        text = "SELECT "  # we're off to a good start !

        if select_stmt._hints:
            hint_text, byfrom = self._setup_select_hints(select_stmt)
            if hint_text:
                text += hint_text + " "
        else:
            byfrom = None

        if select_stmt._independent_ctes:
            self._dispatch_independent_ctes(select_stmt, kwargs)

        if select_stmt._prefixes:
            text += self._generate_prefixes(
                select_stmt, select_stmt._prefixes, **kwargs
            )

        text += self.get_select_precolumns(select_stmt, **kwargs)
        # the actual list of columns to print in the SELECT column list.
        inner_columns = [
            c
            for c in [
                self._label_select_column(
                    select_stmt,
                    column,
                    populate_result_map,
                    asfrom,
                    column_clause_args,
                    name=name,
                    proxy_name=proxy_name,
                    fallback_label_name=fallback_label_name,
                    column_is_repeated=repeated,
                    need_column_expressions=need_column_expressions,
                )
                for (
                    name,
                    proxy_name,
                    fallback_label_name,
                    column,
                    repeated,
                ) in compile_state.columns_plus_names
            ]
            if c is not None
        ]

        if populate_result_map and select_wraps_for is not None:
            # if this select was generated from translate_select,
            # rewrite the targeted columns in the result map

            translate = dict(
                zip(
                    [
                        name
                        for (
                            key,
                            proxy_name,
                            fallback_label_name,
                            name,
                            repeated,
                        ) in compile_state.columns_plus_names
                    ],
                    [
                        name
                        for (
                            key,
                            proxy_name,
                            fallback_label_name,
                            name,
                            repeated,
                        ) in compile_state_wraps_for.columns_plus_names
                    ],
                )
            )

            self._result_columns = [
                ResultColumnsEntry(
                    key, name, tuple(translate.get(o, o) for o in obj), type_
                )
                for key, name, obj, type_ in self._result_columns
            ]

        text = self._compose_select_body(
            text,
            select_stmt,
            compile_state,
            inner_columns,
            froms,
            byfrom,
            toplevel,
            kwargs,
        )

        if select_stmt._statement_hints:
            per_dialect = [
                ht
                for (dialect_name, ht) in select_stmt._statement_hints
                if dialect_name in ("*", self.dialect.name)
            ]
            if per_dialect:
                text += " " + self.get_statement_hint_text(per_dialect)

        # In compound query, CTEs are shared at the compound level
        if self.ctes and (not is_embedded_select or toplevel):
            nesting_level = len(self.stack) if not toplevel else None
            text = self._render_cte_clause(nesting_level=nesting_level) + text

        if select_stmt._suffixes:
            text += " " + self._generate_prefixes(
                select_stmt, select_stmt._suffixes, **kwargs
            )

        self.stack.pop(-1)

        return text
4906
4907 def _setup_select_hints(
4908 self, select: Select[Any]
4909 ) -> Tuple[str, _FromHintsType]:
4910 byfrom = {
4911 from_: hinttext
4912 % {"name": from_._compiler_dispatch(self, ashint=True)}
4913 for (from_, dialect), hinttext in select._hints.items()
4914 if dialect in ("*", self.dialect.name)
4915 }
4916 hint_text = self.get_select_hint_text(byfrom)
4917 return hint_text, byfrom
4918
    def _setup_select_stack(
        self, select, compile_state, entry, asfrom, lateral, compound_index
    ):
        """Compute the display FROM list for a SELECT, validate compound
        column counts, and push a new compiler stack entry.

        :param entry: the current (enclosing) stack entry.
        :param compound_index: position within an enclosing compound
         select, or ``None``.
        :return: the list of FROM elements to render.
        """
        correlate_froms = entry["correlate_froms"]
        asfrom_froms = entry["asfrom_froms"]

        if compound_index == 0:
            # first element of a compound: remember it so subsequent
            # elements can be checked against its column count
            entry["select_0"] = select
        elif compound_index:
            select_0 = entry["select_0"]
            numcols = len(select_0._all_selected_columns)

            if len(compile_state.columns_plus_names) != numcols:
                raise exc.CompileError(
                    "All selectables passed to "
                    "CompoundSelect must have identical numbers of "
                    "columns; select #%d has %d columns, select "
                    "#%d has %d"
                    % (
                        1,
                        numcols,
                        compound_index + 1,
                        len(select._all_selected_columns),
                    )
                )

        if asfrom and not lateral:
            # as a FROM element (subquery), explicitly correlate only
            # against enclosing froms not introduced at this same level
            froms = compile_state._get_display_froms(
                explicit_correlate_froms=correlate_froms.difference(
                    asfrom_froms
                ),
                implicit_correlate_froms=(),
            )
        else:
            froms = compile_state._get_display_froms(
                explicit_correlate_froms=correlate_froms,
                implicit_correlate_froms=asfrom_froms,
            )

        new_correlate_froms = set(_from_objects(*froms))
        all_correlate_froms = new_correlate_froms.union(correlate_froms)

        new_entry: _CompilerStackEntry = {
            "asfrom_froms": new_correlate_froms,
            "correlate_froms": all_correlate_froms,
            "selectable": select,
            "compile_state": compile_state,
        }
        self.stack.append(new_entry)

        return froms
4970
    def _compose_select_body(
        self,
        text,
        select,
        compile_state,
        inner_columns,
        froms,
        byfrom,
        toplevel,
        kwargs,
    ):
        """Assemble the body of a SELECT statement: column list, FROM,
        WHERE, GROUP BY, HAVING, ORDER BY, row-limiting and FOR UPDATE
        clauses, appended to the already-started *text*.
        """
        text += ", ".join(inner_columns)

        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
            from_linter = FromLinter({}, set())
            warn_linting = self.linting & WARN_LINTING
            if toplevel:
                self.from_linter = from_linter
        else:
            from_linter = None
            warn_linting = False

        # adjust the whitespace for no inner columns, part of #9440,
        # so that a no-col SELECT comes out as "SELECT WHERE..." or
        # "SELECT FROM ...".
        # while it would be better to have built the SELECT starting string
        # without trailing whitespace first, then add whitespace only if inner
        # cols were present, this breaks compatibility with various custom
        # compilation schemes that are currently being tested.
        if not inner_columns:
            text = text.rstrip()

        if froms:
            text += " \nFROM "

            if select._hints:
                text += ", ".join(
                    [
                        f._compiler_dispatch(
                            self,
                            asfrom=True,
                            fromhints=byfrom,
                            from_linter=from_linter,
                            **kwargs,
                        )
                        for f in froms
                    ]
                )
            else:
                text += ", ".join(
                    [
                        f._compiler_dispatch(
                            self,
                            asfrom=True,
                            from_linter=from_linter,
                            **kwargs,
                        )
                        for f in froms
                    ]
                )
        else:
            text += self.default_from()

        if select._where_criteria:
            t = self._generate_delimited_and_list(
                select._where_criteria, from_linter=from_linter, **kwargs
            )
            if t:
                text += " \nWHERE " + t

        if warn_linting:
            assert from_linter is not None
            from_linter.warn()

        if select._group_by_clauses:
            text += self.group_by_clause(select, **kwargs)

        if select._having_criteria:
            t = self._generate_delimited_and_list(
                select._having_criteria, **kwargs
            )
            if t:
                text += " \nHAVING " + t

        if select._order_by_clauses:
            text += self.order_by_clause(select, **kwargs)

        if select._has_row_limiting_clause:
            text += self._row_limit_clause(select, **kwargs)

        if select._for_update_arg is not None:
            text += self.for_update_clause(select, **kwargs)

        return text
5065
5066 def _generate_prefixes(self, stmt, prefixes, **kw):
5067 clause = " ".join(
5068 prefix._compiler_dispatch(self, **kw)
5069 for prefix, dialect_name in prefixes
5070 if dialect_name in (None, "*") or dialect_name == self.dialect.name
5071 )
5072 if clause:
5073 clause += " "
5074 return clause
5075
    def _render_cte_clause(
        self,
        nesting_level=None,
        include_following_stack=False,
    ):
        """Render the ``WITH [RECURSIVE] ...`` clause for the CTEs
        collected so far.

        :param nesting_level: current compiler stack depth; when greater
         than 1, only nesting CTEs belonging to this level are rendered
         and they are removed from the collections afterwards.
        :param include_following_stack: also render the nesting CTEs on
         the next stack.  Useful for SQL structures like UNION or INSERT
         that can wrap SELECT statements containing nesting CTEs.
        """
        if not self.ctes:
            return ""

        ctes: MutableMapping[CTE, str]

        if nesting_level and nesting_level > 1:
            # nested context: select only the nesting CTEs rendered at
            # this level (or the next, when requested)
            ctes = util.OrderedDict()
            for cte in list(self.ctes.keys()):
                cte_level, cte_name, cte_opts = self.level_name_by_cte[
                    cte._get_reference_cte()
                ]
                nesting = cte.nesting or cte_opts.nesting
                is_rendered_level = cte_level == nesting_level or (
                    include_following_stack and cte_level == nesting_level + 1
                )
                if not (nesting and is_rendered_level):
                    continue

                ctes[cte] = self.ctes[cte]

        else:
            ctes = self.ctes

        if not ctes:
            return ""
        ctes_recursive = any([cte.recursive for cte in ctes])

        cte_text = self.get_cte_preamble(ctes_recursive) + " "
        cte_text += ", \n".join([txt for txt in ctes.values()])
        cte_text += "\n "

        if nesting_level and nesting_level > 1:
            # the rendered nesting CTEs are consumed; remove them from
            # all collections so they are not rendered again
            for cte in list(ctes.keys()):
                cte_level, cte_name, cte_opts = self.level_name_by_cte[
                    cte._get_reference_cte()
                ]
                del self.ctes[cte]
                del self.ctes_by_level_name[(cte_level, cte_name)]
                del self.level_name_by_cte[cte._get_reference_cte()]

        return cte_text
5128
5129 def get_cte_preamble(self, recursive):
5130 if recursive:
5131 return "WITH RECURSIVE"
5132 else:
5133 return "WITH"
5134
5135 def get_select_precolumns(self, select: Select[Any], **kw: Any) -> str:
5136 """Called when building a ``SELECT`` statement, position is just
5137 before column list.
5138
5139 """
5140 if select._distinct_on:
5141 util.warn_deprecated(
5142 "DISTINCT ON is currently supported only by the PostgreSQL "
5143 "dialect. Use of DISTINCT ON for other backends is currently "
5144 "silently ignored, however this usage is deprecated, and will "
5145 "raise CompileError in a future release for all backends "
5146 "that do not support this syntax.",
5147 version="1.4",
5148 )
5149 return "DISTINCT " if select._distinct else ""
5150
5151 def group_by_clause(self, select, **kw):
5152 """allow dialects to customize how GROUP BY is rendered."""
5153
5154 group_by = self._generate_delimited_list(
5155 select._group_by_clauses, OPERATORS[operators.comma_op], **kw
5156 )
5157 if group_by:
5158 return " GROUP BY " + group_by
5159 else:
5160 return ""
5161
5162 def order_by_clause(self, select, **kw):
5163 """allow dialects to customize how ORDER BY is rendered."""
5164
5165 order_by = self._generate_delimited_list(
5166 select._order_by_clauses, OPERATORS[operators.comma_op], **kw
5167 )
5168
5169 if order_by:
5170 return " ORDER BY " + order_by
5171 else:
5172 return ""
5173
5174 def for_update_clause(self, select, **kw):
5175 return " FOR UPDATE"
5176
    def returning_clause(
        self,
        stmt: UpdateBase,
        returning_cols: Sequence[_ColumnsClauseElement],
        *,
        populate_result_map: bool,
        **kw: Any,
    ) -> str:
        """Render the RETURNING clause for an INSERT/UPDATE/DELETE
        statement.

        :param stmt: the DML statement being compiled.
        :param returning_cols: the column expressions to render.
        :param populate_result_map: register the rendered columns in the
         result map for result-row matching.
        """
        columns = [
            self._label_returning_column(
                stmt,
                column,
                populate_result_map,
                fallback_label_name=fallback_label_name,
                column_is_repeated=repeated,
                name=name,
                proxy_name=proxy_name,
                **kw,
            )
            for (
                name,
                proxy_name,
                fallback_label_name,
                column,
                repeated,
            ) in stmt._generate_columns_plus_names(
                True, cols=base._select_iterables(returning_cols)
            )
        ]

        return "RETURNING " + ", ".join(columns)
5208
5209 def limit_clause(self, select, **kw):
5210 text = ""
5211 if select._limit_clause is not None:
5212 text += "\n LIMIT " + self.process(select._limit_clause, **kw)
5213 if select._offset_clause is not None:
5214 if select._limit_clause is None:
5215 text += "\n LIMIT -1"
5216 text += " OFFSET " + self.process(select._offset_clause, **kw)
5217 return text
5218
    def fetch_clause(
        self,
        select,
        fetch_clause=None,
        require_offset=False,
        use_literal_execute_for_simple_int=False,
        **kw,
    ):
        """Render the ANSI ``OFFSET ... ROWS FETCH FIRST ... ROWS``
        clause.

        :param fetch_clause: explicit FETCH expression; when ``None``
         the statement's own fetch clause and options are used.
        :param require_offset: emit ``OFFSET 0 ROWS`` even when no
         offset is present (for syntaxes that require it).
        :param use_literal_execute_for_simple_int: render simple integer
         limit/offset values as literal-execute parameters.
        """
        if fetch_clause is None:
            fetch_clause = select._fetch_clause
            fetch_clause_options = select._fetch_clause_options
        else:
            fetch_clause_options = {"percent": False, "with_ties": False}

        text = ""

        if select._offset_clause is not None:
            offset_clause = select._offset_clause
            if (
                use_literal_execute_for_simple_int
                and select._simple_int_clause(offset_clause)
            ):
                offset_clause = offset_clause.render_literal_execute()
            offset_str = self.process(offset_clause, **kw)
            text += "\n OFFSET %s ROWS" % offset_str
        elif require_offset:
            text += "\n OFFSET 0 ROWS"

        if fetch_clause is not None:
            if (
                use_literal_execute_for_simple_int
                and select._simple_int_clause(fetch_clause)
            ):
                fetch_clause = fetch_clause.render_literal_execute()
            text += "\n FETCH FIRST %s%s ROWS %s" % (
                self.process(fetch_clause, **kw),
                " PERCENT" if fetch_clause_options["percent"] else "",
                "WITH TIES" if fetch_clause_options["with_ties"] else "ONLY",
            )
        return text
5259
    def visit_table(
        self,
        table,
        asfrom=False,
        iscrud=False,
        ashint=False,
        fromhints=None,
        use_schema=True,
        from_linter=None,
        ambiguous_table_name_map=None,
        enclosing_alias=None,
        **kwargs,
    ):
        """Render a :class:`.Table` as a FROM element or hint target;
        outside those contexts the table renders as an empty string.

        :param use_schema: prepend the quoted schema name when the
         table has an effective schema.
        :param ambiguous_table_name_map: mapping of table names that
         collide across schemas; a matching unqualified name receives an
         anonymized alias.
        """
        if from_linter:
            from_linter.froms[table] = table.fullname

        if asfrom or ashint:
            effective_schema = self.preparer.schema_for_object(table)

            if use_schema and effective_schema:
                ret = (
                    self.preparer.quote_schema(effective_schema)
                    + "."
                    + self.preparer.quote(table.name)
                )
            else:
                ret = self.preparer.quote(table.name)

            if (
                (
                    enclosing_alias is None
                    or enclosing_alias.element is not table
                )
                and not effective_schema
                and ambiguous_table_name_map
                and table.name in ambiguous_table_name_map
            ):
                # unqualified name collides with a table of the same
                # name elsewhere; disambiguate with an anonymized alias
                anon_name = self._truncated_identifier(
                    "alias", ambiguous_table_name_map[table.name]
                )

                ret = ret + self.get_render_as_alias_suffix(
                    self.preparer.format_alias(None, anon_name)
                )

            if fromhints and table in fromhints:
                ret = self.format_from_hint_text(
                    ret, table, fromhints[table], iscrud
                )
            return ret
        else:
            return ""
5312
5313 def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
5314 if from_linter:
5315 from_linter.edges.update(
5316 itertools.product(
5317 _de_clone(join.left._from_objects),
5318 _de_clone(join.right._from_objects),
5319 )
5320 )
5321
5322 if join.full:
5323 join_type = " FULL OUTER JOIN "
5324 elif join.isouter:
5325 join_type = " LEFT OUTER JOIN "
5326 else:
5327 join_type = " JOIN "
5328 return (
5329 join.left._compiler_dispatch(
5330 self, asfrom=True, from_linter=from_linter, **kwargs
5331 )
5332 + join_type
5333 + join.right._compiler_dispatch(
5334 self, asfrom=True, from_linter=from_linter, **kwargs
5335 )
5336 + " ON "
5337 # TODO: likely need asfrom=True here?
5338 + join.onclause._compiler_dispatch(
5339 self, from_linter=from_linter, **kwargs
5340 )
5341 )
5342
5343 def _setup_crud_hints(self, stmt, table_text):
5344 dialect_hints = {
5345 table: hint_text
5346 for (table, dialect), hint_text in stmt._hints.items()
5347 if dialect in ("*", self.dialect.name)
5348 }
5349 if stmt.table in dialect_hints:
5350 table_text = self.format_from_hint_text(
5351 table_text, stmt.table, dialect_hints[stmt.table], True
5352 )
5353 return dialect_hints, table_text
5354
    # within the realm of "insertmanyvalues sentinel columns",
    # these lookups match different kinds of Column() configurations
    # to specific backend capabilities. they are broken into two
    # lookups, one for autoincrement columns and the other for non
    # autoincrement columns
    _sentinel_col_non_autoinc_lookup = util.immutabledict(
        {
            # client-side generated default: usable regardless of any
            # special backend sentinel support
            _SentinelDefaultCharacterization.CLIENTSIDE: (
                InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT
            ),
            # explicit "sentinel default": likewise usable regardless of
            # backend support
            _SentinelDefaultCharacterization.SENTINEL_DEFAULT: (
                InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT
            ),
            # no default at all: likewise usable regardless of
            # backend support
            _SentinelDefaultCharacterization.NONE: (
                InsertmanyvaluesSentinelOpts._SUPPORTED_OR_NOT
            ),
            # server-side IDENTITY default: maps to the IDENTITY
            # capability bit
            _SentinelDefaultCharacterization.IDENTITY: (
                InsertmanyvaluesSentinelOpts.IDENTITY
            ),
            # server-side SEQUENCE default: maps to the SEQUENCE
            # capability bit
            _SentinelDefaultCharacterization.SEQUENCE: (
                InsertmanyvaluesSentinelOpts.SEQUENCE
            ),
        }
    )
    # autoincrement columns share the same mapping, except that a column
    # with no explicit default maps to the AUTOINCREMENT capability bit
    _sentinel_col_autoinc_lookup = _sentinel_col_non_autoinc_lookup.union(
        {
            _SentinelDefaultCharacterization.NONE: (
                InsertmanyvaluesSentinelOpts.AUTOINCREMENT
            ),
        }
    )
5386
5387 def _get_sentinel_column_for_table(
5388 self, table: Table
5389 ) -> Optional[Sequence[Column[Any]]]:
5390 """given a :class:`.Table`, return a usable sentinel column or
5391 columns for this dialect if any.
5392
5393 Return None if no sentinel columns could be identified, or raise an
5394 error if a column was marked as a sentinel explicitly but isn't
5395 compatible with this dialect.
5396
5397 """
5398
5399 sentinel_opts = self.dialect.insertmanyvalues_implicit_sentinel
5400 sentinel_characteristics = table._sentinel_column_characteristics
5401
5402 sent_cols = sentinel_characteristics.columns
5403
5404 if sent_cols is None:
5405 return None
5406
5407 if sentinel_characteristics.is_autoinc:
5408 bitmask = self._sentinel_col_autoinc_lookup.get(
5409 sentinel_characteristics.default_characterization, 0
5410 )
5411 else:
5412 bitmask = self._sentinel_col_non_autoinc_lookup.get(
5413 sentinel_characteristics.default_characterization, 0
5414 )
5415
5416 if sentinel_opts & bitmask:
5417 return sent_cols
5418
5419 if sentinel_characteristics.is_explicit:
5420 # a column was explicitly marked as insert_sentinel=True,
5421 # however it is not compatible with this dialect. they should
5422 # not indicate this column as a sentinel if they need to include
5423 # this dialect.
5424
5425 # TODO: do we want non-primary key explicit sentinel cols
5426 # that can gracefully degrade for some backends?
5427 # insert_sentinel="degrade" perhaps. not for the initial release.
5428 # I am hoping people are generally not dealing with this sentinel
5429 # business at all.
5430
5431 # if is_explicit is True, there will be only one sentinel column.
5432
5433 raise exc.InvalidRequestError(
5434 f"Column {sent_cols[0]} can't be explicitly "
5435 "marked as a sentinel column when using the "
5436 f"{self.dialect.name} dialect, as the "
5437 "particular type of default generation on this column is "
5438 "not currently compatible with this dialect's specific "
5439 f"INSERT..RETURNING syntax which can receive the "
5440 "server-generated value in "
5441 "a deterministic way. To remove this error, remove "
5442 "insert_sentinel=True from primary key autoincrement "
5443 "columns; these columns are automatically used as "
5444 "sentinels for supported dialects in any case."
5445 )
5446
5447 return None
5448
    def _deliver_insertmanyvalues_batches(
        self,
        statement: str,
        parameters: _DBAPIMultiExecuteParams,
        compiled_parameters: List[_MutableCoreSingleExecuteParams],
        generic_setinputsizes: Optional[_GenericSetInputSizesType],
        batch_size: int,
        sort_by_parameter_order: bool,
        schema_translate_map: Optional[SchemaTranslateMapType],
    ) -> Iterator[_InsertManyValuesBatch]:
        """Rewrite a compiled single-row INSERT into a series of multi-row
        statements, yielding one :class:`._InsertManyValuesBatch` per
        statement to be executed.

        The single VALUES expression recorded in ``self._insertmanyvalues``
        is located within ``statement`` and replaced with the placeholder
        ``__EXECMANY_TOKEN__``, which is then expanded per batch to as many
        VALUES entries as the batch has parameter sets.  When the dialect
        can't support the multi-row form for this statement, falls back to
        yielding one single-row batch per parameter set.

        :param statement: compiled SQL string for the single-row INSERT
        :param parameters: DBAPI-level parameter sets
        :param compiled_parameters: Core-level parameter dictionaries,
         parallel to ``parameters``
        :param generic_setinputsizes: optional setinputsizes collection,
         expanded to match each batch's length
        :param batch_size: maximum number of parameter sets per batch
        :param sort_by_parameter_order: whether deterministic RETURNING
         ordering was requested
        :param schema_translate_map: optional schema translate map applied
         to rendered schema/table tokens

        """
        imv = self._insertmanyvalues
        assert imv is not None

        # build an extractor for sentinel values out of each compiled
        # parameter dictionary, if sentinel parameter keys are present
        if not imv.sentinel_param_keys:
            _sentinel_from_params = None
        else:
            _sentinel_from_params = operator.itemgetter(
                *imv.sentinel_param_keys
            )

        lenparams = len(parameters)
        if imv.is_default_expr and not self.dialect.supports_default_metavalue:
            # backend doesn't support
            # INSERT INTO table (pk_col) VALUES (DEFAULT), (DEFAULT), ...
            # at the moment this is basically SQL Server due to
            # not being able to use DEFAULT for identity column
            # just yield out that many single statements! still
            # faster than a whole connection.execute() call ;)
            #
            # note we still are taking advantage of the fact that we know
            # we are using RETURNING. The generalized approach of fetching
            # cursor.lastrowid etc. still goes through the more heavyweight
            # "ExecutionContext per statement" system as it isn't usable
            # as a generic "RETURNING" approach
            use_row_at_a_time = True
            downgraded = False
        elif not self.dialect.supports_multivalues_insert or (
            sort_by_parameter_order
            and self._result_columns
            and (imv.sentinel_columns is None or imv.includes_upsert_behaviors)
        ):
            # deterministic order was requested and the compiler could
            # not organize sentinel columns for this dialect/statement.
            # use row at a time
            use_row_at_a_time = True
            downgraded = True
        else:
            use_row_at_a_time = False
            downgraded = False

        if use_row_at_a_time:
            # yield the original single-row statement once per parameter
            # set; each batch carries exactly one parameter set
            for batchnum, (param, compiled_param) in enumerate(
                cast(
                    "Sequence[Tuple[_DBAPISingleExecuteParams, _MutableCoreSingleExecuteParams]]",  # noqa: E501
                    zip(parameters, compiled_parameters),
                ),
                1,
            ):
                yield _InsertManyValuesBatch(
                    statement,
                    param,
                    generic_setinputsizes,
                    [param],
                    (
                        [_sentinel_from_params(compiled_param)]
                        if _sentinel_from_params
                        else []
                    ),
                    1,
                    batchnum,
                    lenparams,
                    sort_by_parameter_order,
                    downgraded,
                )
            return

        if schema_translate_map:
            rst = functools.partial(
                self.preparer._render_schema_translates,
                schema_translate_map=schema_translate_map,
            )
        else:
            rst = None

        imv_single_values_expr = imv.single_values_expr
        if rst:
            imv_single_values_expr = rst(imv_single_values_expr)

        # replace the single VALUES entry with a token to be expanded
        # per batch below
        executemany_values = f"({imv_single_values_expr})"
        statement = statement.replace(executemany_values, "__EXECMANY_TOKEN__")

        # Use optional insertmanyvalues_max_parameters
        # to further shrink the batch size so that there are no more than
        # insertmanyvalues_max_parameters params.
        # Currently used by SQL Server, which limits statements to 2100 bound
        # parameters (actually 2099).
        max_params = self.dialect.insertmanyvalues_max_parameters
        if max_params:
            total_num_of_params = len(self.bind_names)
            num_params_per_batch = len(imv.insert_crud_params)
            num_params_outside_of_batch = (
                total_num_of_params - num_params_per_batch
            )
            batch_size = min(
                batch_size,
                (
                    (max_params - num_params_outside_of_batch)
                    // num_params_per_batch
                ),
            )

        batches = cast("List[Sequence[Any]]", list(parameters))
        compiled_batches = cast(
            "List[Sequence[Any]]", list(compiled_parameters)
        )

        processed_setinputsizes: Optional[_GenericSetInputSizesType] = None
        batchnum = 1
        total_batches = lenparams // batch_size + (
            1 if lenparams % batch_size else 0
        )

        insert_crud_params = imv.insert_crud_params
        assert insert_crud_params is not None

        if rst:
            insert_crud_params = [
                (col, key, rst(expr), st)
                for col, key, expr, st in insert_crud_params
            ]

        escaped_bind_names: Mapping[str, str]
        expand_pos_lower_index = expand_pos_upper_index = 0

        if not self.positional:
            # named paramstyle: parameters will be renamed per row with an
            # __<index> suffix
            if self.escaped_bind_names:
                escaped_bind_names = self.escaped_bind_names
            else:
                escaped_bind_names = {}

            all_keys = set(parameters[0])

            def apply_placeholders(keys, formatted):
                # rewrite each bound parameter name in the formatted
                # VALUES fragment to carry a per-row index placeholder
                for key in keys:
                    key = escaped_bind_names.get(key, key)
                    formatted = formatted.replace(
                        self.bindtemplate % {"name": key},
                        self.bindtemplate
                        % {"name": f"{key}__EXECMANY_INDEX__"},
                    )
                return formatted

            if imv.embed_values_counter:
                imv_values_counter = ", _IMV_VALUES_COUNTER"
            else:
                imv_values_counter = ""
            formatted_values_clause = f"""({', '.join(
                apply_placeholders(bind_keys, formatted)
                for _, _, formatted, bind_keys in insert_crud_params
            )}{imv_values_counter})"""

            keys_to_replace = all_keys.intersection(
                escaped_bind_names.get(key, key)
                for _, _, _, bind_keys in insert_crud_params
                for key in bind_keys
            )
            base_parameters = {
                key: parameters[0][key]
                for key in all_keys.difference(keys_to_replace)
            }
            executemany_values_w_comma = ""
        else:
            formatted_values_clause = ""
            keys_to_replace = set()
            base_parameters = {}

            if imv.embed_values_counter:
                executemany_values_w_comma = (
                    f"({imv_single_values_expr}, _IMV_VALUES_COUNTER), "
                )
            else:
                executemany_values_w_comma = f"({imv_single_values_expr}), "

            all_names_we_will_expand: Set[str] = set()
            for elem in imv.insert_crud_params:
                all_names_we_will_expand.update(elem[3])

            # get the start and end position in a particular list
            # of parameters where we will be doing the "expanding".
            # statements can have params on either side or both sides,
            # given RETURNING and CTEs
            if all_names_we_will_expand:
                positiontup = self.positiontup
                assert positiontup is not None

                all_expand_positions = {
                    idx
                    for idx, name in enumerate(positiontup)
                    if name in all_names_we_will_expand
                }
                expand_pos_lower_index = min(all_expand_positions)
                expand_pos_upper_index = max(all_expand_positions) + 1
                # the expanded positions must be contiguous
                assert (
                    len(all_expand_positions)
                    == expand_pos_upper_index - expand_pos_lower_index
                )

            if self._numeric_binds:
                escaped = re.escape(self._numeric_binds_identifier_char)
                executemany_values_w_comma = re.sub(
                    rf"{escaped}\d+", "%s", executemany_values_w_comma
                )

        while batches:
            # consume up to batch_size parameter sets from the front
            batch = batches[0:batch_size]
            compiled_batch = compiled_batches[0:batch_size]

            batches[0:batch_size] = []
            compiled_batches[0:batch_size] = []

            if batches:
                current_batch_size = batch_size
            else:
                current_batch_size = len(batch)

            if generic_setinputsizes:
                # if setinputsizes is present, expand this collection to
                # suit the batch length as well
                # currently this will be mssql+pyodbc for internal dialects
                processed_setinputsizes = [
                    (new_key, len_, typ)
                    for new_key, len_, typ in (
                        (f"{key}_{index}", len_, typ)
                        for index in range(current_batch_size)
                        for key, len_, typ in generic_setinputsizes
                    )
                ]

            replaced_parameters: Any
            if self.positional:
                num_ins_params = imv.num_positional_params_counted

                batch_iterator: Iterable[Sequence[Any]]
                extra_params_left: Sequence[Any]
                extra_params_right: Sequence[Any]

                if num_ins_params == len(batch[0]):
                    extra_params_left = extra_params_right = ()
                    batch_iterator = batch
                else:
                    extra_params_left = batch[0][:expand_pos_lower_index]
                    extra_params_right = batch[0][expand_pos_upper_index:]
                    batch_iterator = (
                        b[expand_pos_lower_index:expand_pos_upper_index]
                        for b in batch
                    )

                if imv.embed_values_counter:
                    expanded_values_string = (
                        "".join(
                            executemany_values_w_comma.replace(
                                "_IMV_VALUES_COUNTER", str(i)
                            )
                            for i, _ in enumerate(batch)
                        )
                    )[:-2]
                else:
                    expanded_values_string = (
                        (executemany_values_w_comma * current_batch_size)
                    )[:-2]

                if self._numeric_binds and num_ins_params > 0:
                    # numeric will always number the parameters inside of
                    # VALUES (and thus order self.positiontup) to be higher
                    # than non-VALUES parameters, no matter where in the
                    # statement those non-VALUES parameters appear (this is
                    # ensured in _process_numeric by numbering first all
                    # params that are not in _values_bindparam)
                    # therefore all extra params are always
                    # on the left side and numbered lower than the VALUES
                    # parameters
                    assert not extra_params_right

                    start = expand_pos_lower_index + 1
                    end = num_ins_params * (current_batch_size) + start

                    # need to format here, since statement may contain
                    # unescaped %, while values_string contains just (%s, %s)
                    positions = tuple(
                        f"{self._numeric_binds_identifier_char}{i}"
                        for i in range(start, end)
                    )
                    expanded_values_string = expanded_values_string % positions

                replaced_statement = statement.replace(
                    "__EXECMANY_TOKEN__", expanded_values_string
                )

                replaced_parameters = tuple(
                    itertools.chain.from_iterable(batch_iterator)
                )

                replaced_parameters = (
                    extra_params_left
                    + replaced_parameters
                    + extra_params_right
                )

            else:
                # named paramstyle: substitute the per-row index into the
                # formatted VALUES clause and parameter names
                replaced_values_clauses = []
                replaced_parameters = base_parameters.copy()

                for i, param in enumerate(batch):
                    fmv = formatted_values_clause.replace(
                        "EXECMANY_INDEX__", str(i)
                    )
                    if imv.embed_values_counter:
                        fmv = fmv.replace("_IMV_VALUES_COUNTER", str(i))

                    replaced_values_clauses.append(fmv)
                    replaced_parameters.update(
                        {f"{key}__{i}": param[key] for key in keys_to_replace}
                    )

                replaced_statement = statement.replace(
                    "__EXECMANY_TOKEN__",
                    ", ".join(replaced_values_clauses),
                )

            yield _InsertManyValuesBatch(
                replaced_statement,
                replaced_parameters,
                processed_setinputsizes,
                batch,
                (
                    [_sentinel_from_params(cb) for cb in compiled_batch]
                    if _sentinel_from_params
                    else []
                ),
                current_batch_size,
                batchnum,
                total_batches,
                sort_by_parameter_order,
                False,
            )
            batchnum += 1
5795
5796 def visit_insert(
5797 self, insert_stmt, visited_bindparam=None, visiting_cte=None, **kw
5798 ):
5799 compile_state = insert_stmt._compile_state_factory(
5800 insert_stmt, self, **kw
5801 )
5802 insert_stmt = compile_state.statement
5803
5804 if visiting_cte is not None:
5805 kw["visiting_cte"] = visiting_cte
5806 toplevel = False
5807 else:
5808 toplevel = not self.stack
5809
5810 if toplevel:
5811 self.isinsert = True
5812 if not self.dml_compile_state:
5813 self.dml_compile_state = compile_state
5814 if not self.compile_state:
5815 self.compile_state = compile_state
5816
5817 self.stack.append(
5818 {
5819 "correlate_froms": set(),
5820 "asfrom_froms": set(),
5821 "selectable": insert_stmt,
5822 }
5823 )
5824
5825 counted_bindparam = 0
5826
5827 # reset any incoming "visited_bindparam" collection
5828 visited_bindparam = None
5829
5830 # for positional, insertmanyvalues needs to know how many
5831 # bound parameters are in the VALUES sequence; there's no simple
5832 # rule because default expressions etc. can have zero or more
5833 # params inside them. After multiple attempts to figure this out,
5834 # this very simplistic "count after" works and is
5835 # likely the least amount of callcounts, though looks clumsy
5836 if self.positional and visiting_cte is None:
5837 # if we are inside a CTE, don't count parameters
5838 # here since they wont be for insertmanyvalues. keep
5839 # visited_bindparam at None so no counting happens.
5840 # see #9173
5841 visited_bindparam = []
5842
5843 crud_params_struct = crud._get_crud_params(
5844 self,
5845 insert_stmt,
5846 compile_state,
5847 toplevel,
5848 visited_bindparam=visited_bindparam,
5849 **kw,
5850 )
5851
5852 if self.positional and visited_bindparam is not None:
5853 counted_bindparam = len(visited_bindparam)
5854 if self._numeric_binds:
5855 if self._values_bindparam is not None:
5856 self._values_bindparam += visited_bindparam
5857 else:
5858 self._values_bindparam = visited_bindparam
5859
5860 crud_params_single = crud_params_struct.single_params
5861
5862 if (
5863 not crud_params_single
5864 and not self.dialect.supports_default_values
5865 and not self.dialect.supports_default_metavalue
5866 and not self.dialect.supports_empty_insert
5867 ):
5868 raise exc.CompileError(
5869 "The '%s' dialect with current database "
5870 "version settings does not support empty "
5871 "inserts." % self.dialect.name
5872 )
5873
5874 if compile_state._has_multi_parameters:
5875 if not self.dialect.supports_multivalues_insert:
5876 raise exc.CompileError(
5877 "The '%s' dialect with current database "
5878 "version settings does not support "
5879 "in-place multirow inserts." % self.dialect.name
5880 )
5881 elif (
5882 self.implicit_returning or insert_stmt._returning
5883 ) and insert_stmt._sort_by_parameter_order:
5884 raise exc.CompileError(
5885 "RETURNING cannot be determinstically sorted when "
5886 "using an INSERT which includes multi-row values()."
5887 )
5888 crud_params_single = crud_params_struct.single_params
5889 else:
5890 crud_params_single = crud_params_struct.single_params
5891
5892 preparer = self.preparer
5893 supports_default_values = self.dialect.supports_default_values
5894
5895 text = "INSERT "
5896
5897 if insert_stmt._prefixes:
5898 text += self._generate_prefixes(
5899 insert_stmt, insert_stmt._prefixes, **kw
5900 )
5901
5902 text += "INTO "
5903 table_text = preparer.format_table(insert_stmt.table)
5904
5905 if insert_stmt._hints:
5906 _, table_text = self._setup_crud_hints(insert_stmt, table_text)
5907
5908 if insert_stmt._independent_ctes:
5909 self._dispatch_independent_ctes(insert_stmt, kw)
5910
5911 text += table_text
5912
5913 if crud_params_single or not supports_default_values:
5914 text += " (%s)" % ", ".join(
5915 [expr for _, expr, _, _ in crud_params_single]
5916 )
5917
5918 # look for insertmanyvalues attributes that would have been configured
5919 # by crud.py as it scanned through the columns to be part of the
5920 # INSERT
5921 use_insertmanyvalues = crud_params_struct.use_insertmanyvalues
5922 named_sentinel_params: Optional[Sequence[str]] = None
5923 add_sentinel_cols = None
5924 implicit_sentinel = False
5925
5926 returning_cols = self.implicit_returning or insert_stmt._returning
5927 if returning_cols:
5928 add_sentinel_cols = crud_params_struct.use_sentinel_columns
5929 if add_sentinel_cols is not None:
5930 assert use_insertmanyvalues
5931
5932 # search for the sentinel column explicitly present
5933 # in the INSERT columns list, and additionally check that
5934 # this column has a bound parameter name set up that's in the
5935 # parameter list. If both of these cases are present, it means
5936 # we will have a client side value for the sentinel in each
5937 # parameter set.
5938
5939 _params_by_col = {
5940 col: param_names
5941 for col, _, _, param_names in crud_params_single
5942 }
5943 named_sentinel_params = []
5944 for _add_sentinel_col in add_sentinel_cols:
5945 if _add_sentinel_col not in _params_by_col:
5946 named_sentinel_params = None
5947 break
5948 param_name = self._within_exec_param_key_getter(
5949 _add_sentinel_col
5950 )
5951 if param_name not in _params_by_col[_add_sentinel_col]:
5952 named_sentinel_params = None
5953 break
5954 named_sentinel_params.append(param_name)
5955
5956 if named_sentinel_params is None:
5957 # if we are not going to have a client side value for
5958 # the sentinel in the parameter set, that means it's
5959 # an autoincrement, an IDENTITY, or a server-side SQL
5960 # expression like nextval('seqname'). So this is
5961 # an "implicit" sentinel; we will look for it in
5962 # RETURNING
5963 # only, and then sort on it. For this case on PG,
5964 # SQL Server we have to use a special INSERT form
5965 # that guarantees the server side function lines up with
5966 # the entries in the VALUES.
5967 if (
5968 self.dialect.insertmanyvalues_implicit_sentinel
5969 & InsertmanyvaluesSentinelOpts.ANY_AUTOINCREMENT
5970 ):
5971 implicit_sentinel = True
5972 else:
5973 # here, we are not using a sentinel at all
5974 # and we are likely the SQLite dialect.
5975 # The first add_sentinel_col that we have should not
5976 # be marked as "insert_sentinel=True". if it was,
5977 # an error should have been raised in
5978 # _get_sentinel_column_for_table.
5979 assert not add_sentinel_cols[0]._insert_sentinel, (
5980 "sentinel selection rules should have prevented "
5981 "us from getting here for this dialect"
5982 )
5983
5984 # always put the sentinel columns last. even if they are
5985 # in the returning list already, they will be there twice
5986 # then.
5987 returning_cols = list(returning_cols) + list(add_sentinel_cols)
5988
5989 returning_clause = self.returning_clause(
5990 insert_stmt,
5991 returning_cols,
5992 populate_result_map=toplevel,
5993 )
5994
5995 if self.returning_precedes_values:
5996 text += " " + returning_clause
5997
5998 else:
5999 returning_clause = None
6000
6001 if insert_stmt.select is not None:
6002 # placed here by crud.py
6003 select_text = self.process(
6004 self.stack[-1]["insert_from_select"], insert_into=True, **kw
6005 )
6006
6007 if self.ctes and self.dialect.cte_follows_insert:
6008 nesting_level = len(self.stack) if not toplevel else None
6009 text += " %s%s" % (
6010 self._render_cte_clause(
6011 nesting_level=nesting_level,
6012 include_following_stack=True,
6013 ),
6014 select_text,
6015 )
6016 else:
6017 text += " %s" % select_text
6018 elif not crud_params_single and supports_default_values:
6019 text += " DEFAULT VALUES"
6020 if use_insertmanyvalues:
6021 self._insertmanyvalues = _InsertManyValues(
6022 True,
6023 self.dialect.default_metavalue_token,
6024 cast(
6025 "List[crud._CrudParamElementStr]", crud_params_single
6026 ),
6027 counted_bindparam,
6028 sort_by_parameter_order=(
6029 insert_stmt._sort_by_parameter_order
6030 ),
6031 includes_upsert_behaviors=(
6032 insert_stmt._post_values_clause is not None
6033 ),
6034 sentinel_columns=add_sentinel_cols,
6035 num_sentinel_columns=(
6036 len(add_sentinel_cols) if add_sentinel_cols else 0
6037 ),
6038 implicit_sentinel=implicit_sentinel,
6039 )
6040 elif compile_state._has_multi_parameters:
6041 text += " VALUES %s" % (
6042 ", ".join(
6043 "(%s)"
6044 % (", ".join(value for _, _, value, _ in crud_param_set))
6045 for crud_param_set in crud_params_struct.all_multi_params
6046 ),
6047 )
6048 else:
6049 insert_single_values_expr = ", ".join(
6050 [
6051 value
6052 for _, _, value, _ in cast(
6053 "List[crud._CrudParamElementStr]",
6054 crud_params_single,
6055 )
6056 ]
6057 )
6058
6059 if use_insertmanyvalues:
6060 if (
6061 implicit_sentinel
6062 and (
6063 self.dialect.insertmanyvalues_implicit_sentinel
6064 & InsertmanyvaluesSentinelOpts.USE_INSERT_FROM_SELECT
6065 )
6066 # this is checking if we have
6067 # INSERT INTO table (id) VALUES (DEFAULT).
6068 and not (crud_params_struct.is_default_metavalue_only)
6069 ):
6070 # if we have a sentinel column that is server generated,
6071 # then for selected backends render the VALUES list as a
6072 # subquery. This is the orderable form supported by
6073 # PostgreSQL and SQL Server.
6074 embed_sentinel_value = True
6075
6076 render_bind_casts = (
6077 self.dialect.insertmanyvalues_implicit_sentinel
6078 & InsertmanyvaluesSentinelOpts.RENDER_SELECT_COL_CASTS
6079 )
6080
6081 colnames = ", ".join(
6082 f"p{i}" for i, _ in enumerate(crud_params_single)
6083 )
6084
6085 if render_bind_casts:
6086 # render casts for the SELECT list. For PG, we are
6087 # already rendering bind casts in the parameter list,
6088 # selectively for the more "tricky" types like ARRAY.
6089 # however, even for the "easy" types, if the parameter
6090 # is NULL for every entry, PG gives up and says
6091 # "it must be TEXT", which fails for other easy types
6092 # like ints. So we cast on this side too.
6093 colnames_w_cast = ", ".join(
6094 self.render_bind_cast(
6095 col.type,
6096 col.type._unwrapped_dialect_impl(self.dialect),
6097 f"p{i}",
6098 )
6099 for i, (col, *_) in enumerate(crud_params_single)
6100 )
6101 else:
6102 colnames_w_cast = colnames
6103
6104 text += (
6105 f" SELECT {colnames_w_cast} FROM "
6106 f"(VALUES ({insert_single_values_expr})) "
6107 f"AS imp_sen({colnames}, sen_counter) "
6108 "ORDER BY sen_counter"
6109 )
6110 else:
6111 # otherwise, if no sentinel or backend doesn't support
6112 # orderable subquery form, use a plain VALUES list
6113 embed_sentinel_value = False
6114 text += f" VALUES ({insert_single_values_expr})"
6115
6116 self._insertmanyvalues = _InsertManyValues(
6117 is_default_expr=False,
6118 single_values_expr=insert_single_values_expr,
6119 insert_crud_params=cast(
6120 "List[crud._CrudParamElementStr]",
6121 crud_params_single,
6122 ),
6123 num_positional_params_counted=counted_bindparam,
6124 sort_by_parameter_order=(
6125 insert_stmt._sort_by_parameter_order
6126 ),
6127 includes_upsert_behaviors=(
6128 insert_stmt._post_values_clause is not None
6129 ),
6130 sentinel_columns=add_sentinel_cols,
6131 num_sentinel_columns=(
6132 len(add_sentinel_cols) if add_sentinel_cols else 0
6133 ),
6134 sentinel_param_keys=named_sentinel_params,
6135 implicit_sentinel=implicit_sentinel,
6136 embed_values_counter=embed_sentinel_value,
6137 )
6138
6139 else:
6140 text += f" VALUES ({insert_single_values_expr})"
6141
6142 if insert_stmt._post_values_clause is not None:
6143 post_values_clause = self.process(
6144 insert_stmt._post_values_clause, **kw
6145 )
6146 if post_values_clause:
6147 text += " " + post_values_clause
6148
6149 if returning_clause and not self.returning_precedes_values:
6150 text += " " + returning_clause
6151
6152 if self.ctes and not self.dialect.cte_follows_insert:
6153 nesting_level = len(self.stack) if not toplevel else None
6154 text = (
6155 self._render_cte_clause(
6156 nesting_level=nesting_level,
6157 include_following_stack=True,
6158 )
6159 + text
6160 )
6161
6162 self.stack.pop(-1)
6163
6164 return text
6165
6166 def update_limit_clause(self, update_stmt):
6167 """Provide a hook for MySQL to add LIMIT to the UPDATE"""
6168 return None
6169
6170 def delete_limit_clause(self, delete_stmt):
6171 """Provide a hook for MySQL to add LIMIT to the DELETE"""
6172 return None
6173
6174 def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
6175 """Provide a hook to override the initial table clause
6176 in an UPDATE statement.
6177
6178 MySQL overrides this.
6179
6180 """
6181 kw["asfrom"] = True
6182 return from_table._compiler_dispatch(self, iscrud=True, **kw)
6183
6184 def update_from_clause(
6185 self, update_stmt, from_table, extra_froms, from_hints, **kw
6186 ):
6187 """Provide a hook to override the generation of an
6188 UPDATE..FROM clause.
6189
6190 MySQL and MSSQL override this.
6191
6192 """
6193 raise NotImplementedError(
6194 "This backend does not support multiple-table "
6195 "criteria within UPDATE"
6196 )
6197
    def visit_update(
        self,
        update_stmt: Update,
        visiting_cte: Optional[CTE] = None,
        **kw: Any,
    ) -> str:
        """Render an UPDATE statement.

        Assembles the UPDATE text from the statement's compile state:
        prefixes, hints, SET clause, extra FROM elements for multi-table
        UPDATE (via :meth:`.update_from_clause`), WHERE criteria,
        dialect-specific LIMIT (via :meth:`.update_limit_clause`),
        RETURNING, and CTEs.  Also runs cartesian product linting when
        enabled.
        """
        compile_state = update_stmt._compile_state_factory(  # type: ignore[call-arg] # noqa: E501
            update_stmt, self, **kw  # type: ignore[arg-type]
        )
        if TYPE_CHECKING:
            assert isinstance(compile_state, UpdateDMLState)
        update_stmt = compile_state.statement  # type: ignore[assignment]

        if visiting_cte is not None:
            kw["visiting_cte"] = visiting_cte
            toplevel = False
        else:
            toplevel = not self.stack

        if toplevel:
            self.isupdate = True
            if not self.dml_compile_state:
                self.dml_compile_state = compile_state
            if not self.compile_state:
                self.compile_state = compile_state

        # set up FROM linting when requested by compiler linting flags
        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
            from_linter = FromLinter({}, set())
            warn_linting = self.linting & WARN_LINTING
            if toplevel:
                self.from_linter = from_linter
        else:
            from_linter = None
            warn_linting = False

        extra_froms = compile_state._extra_froms
        is_multitable = bool(extra_froms)

        if is_multitable:
            # main table might be a JOIN
            main_froms = set(_from_objects(update_stmt.table))
            render_extra_froms = [
                f for f in extra_froms if f not in main_froms
            ]
            correlate_froms = main_froms.union(extra_froms)
        else:
            render_extra_froms = []
            correlate_froms = {update_stmt.table}

        self.stack.append(
            {
                "correlate_froms": correlate_froms,
                "asfrom_froms": correlate_froms,
                "selectable": update_stmt,
            }
        )

        text = "UPDATE "

        if update_stmt._prefixes:
            text += self._generate_prefixes(
                update_stmt, update_stmt._prefixes, **kw
            )

        table_text = self.update_tables_clause(
            update_stmt,
            update_stmt.table,
            render_extra_froms,
            from_linter=from_linter,
            **kw,
        )
        crud_params_struct = crud._get_crud_params(
            self, update_stmt, compile_state, toplevel, **kw
        )
        crud_params = crud_params_struct.single_params

        if update_stmt._hints:
            dialect_hints, table_text = self._setup_crud_hints(
                update_stmt, table_text
            )
        else:
            dialect_hints = None

        if update_stmt._independent_ctes:
            self._dispatch_independent_ctes(update_stmt, kw)

        text += table_text

        text += " SET "
        text += ", ".join(
            expr + "=" + value
            for _, expr, value, _ in cast(
                "List[Tuple[Any, str, str, Any]]", crud_params
            )
        )

        # RETURNING renders before or after the rest of the statement
        # depending on the dialect's returning_precedes_values setting
        if self.implicit_returning or update_stmt._returning:
            if self.returning_precedes_values:
                text += " " + self.returning_clause(
                    update_stmt,
                    self.implicit_returning or update_stmt._returning,
                    populate_result_map=toplevel,
                )

        if extra_froms:
            extra_from_text = self.update_from_clause(
                update_stmt,
                update_stmt.table,
                render_extra_froms,
                dialect_hints,
                from_linter=from_linter,
                **kw,
            )
            if extra_from_text:
                text += " " + extra_from_text

        if update_stmt._where_criteria:
            t = self._generate_delimited_and_list(
                update_stmt._where_criteria, from_linter=from_linter, **kw
            )
            if t:
                text += " WHERE " + t

        limit_clause = self.update_limit_clause(update_stmt)
        if limit_clause:
            text += " " + limit_clause

        if (
            self.implicit_returning or update_stmt._returning
        ) and not self.returning_precedes_values:
            text += " " + self.returning_clause(
                update_stmt,
                self.implicit_returning or update_stmt._returning,
                populate_result_map=toplevel,
            )

        if self.ctes:
            nesting_level = len(self.stack) if not toplevel else None
            text = self._render_cte_clause(nesting_level=nesting_level) + text

        if warn_linting:
            assert from_linter is not None
            from_linter.warn(stmt_type="UPDATE")

        self.stack.pop(-1)

        return text
6345
6346 def delete_extra_from_clause(
6347 self, delete_stmt, from_table, extra_froms, from_hints, **kw
6348 ):
6349 """Provide a hook to override the generation of an
6350 DELETE..FROM clause.
6351
6352 This can be used to implement DELETE..USING for example.
6353
6354 MySQL and MSSQL override this.
6355
6356 """
6357 raise NotImplementedError(
6358 "This backend does not support multiple-table "
6359 "criteria within DELETE"
6360 )
6361
6362 def delete_table_clause(self, delete_stmt, from_table, extra_froms, **kw):
6363 return from_table._compiler_dispatch(
6364 self, asfrom=True, iscrud=True, **kw
6365 )
6366
    def visit_delete(self, delete_stmt, visiting_cte=None, **kw):
        """Compile a :class:`.Delete` construct into a DELETE string.

        Handles single- and multiple-table DELETE, dialect hints,
        RETURNING (in either position), WHERE criteria, a dialect
        limit clause, CTE rendering, and FROM-clause linting.
        """
        # the compile-state factory may return an adapted statement
        # (e.g. with ORM transformations applied); use that from here on
        compile_state = delete_stmt._compile_state_factory(
            delete_stmt, self, **kw
        )
        delete_stmt = compile_state.statement

        if visiting_cte is not None:
            # being rendered as the body of a CTE; never toplevel
            kw["visiting_cte"] = visiting_cte
            toplevel = False
        else:
            toplevel = not self.stack

        if toplevel:
            # record statement-level state only for the outermost DML
            self.isdelete = True
            if not self.dml_compile_state:
                self.dml_compile_state = compile_state
            if not self.compile_state:
                self.compile_state = compile_state

        # set up cartesian-product linting if enabled; only the toplevel
        # statement installs its linter on the compiler itself
        if self.linting & COLLECT_CARTESIAN_PRODUCTS:
            from_linter = FromLinter({}, set())
            warn_linting = self.linting & WARN_LINTING
            if toplevel:
                self.from_linter = from_linter
        else:
            from_linter = None
            warn_linting = False

        extra_froms = compile_state._extra_froms

        # embedded subqueries may correlate against the target table
        # and any extra FROM entries
        correlate_froms = {delete_stmt.table}.union(extra_froms)
        self.stack.append(
            {
                "correlate_froms": correlate_froms,
                "asfrom_froms": correlate_froms,
                "selectable": delete_stmt,
            }
        )

        text = "DELETE "

        if delete_stmt._prefixes:
            text += self._generate_prefixes(
                delete_stmt, delete_stmt._prefixes, **kw
            )

        text += "FROM "

        try:
            table_text = self.delete_table_clause(
                delete_stmt,
                delete_stmt.table,
                extra_froms,
                from_linter=from_linter,
            )
        except TypeError:
            # anticipate 3rd party dialects that don't include **kw
            # TODO: remove in 2.1
            table_text = self.delete_table_clause(
                delete_stmt, delete_stmt.table, extra_froms
            )
            if from_linter:
                # perform the linter bookkeeping the kw-aware call
                # would have done
                _ = self.process(delete_stmt.table, from_linter=from_linter)

        # populates result-map / parameter state; DELETE renders no
        # SET clause so the return value is not used here
        crud._get_crud_params(self, delete_stmt, compile_state, toplevel, **kw)

        if delete_stmt._hints:
            dialect_hints, table_text = self._setup_crud_hints(
                delete_stmt, table_text
            )
        else:
            dialect_hints = None

        if delete_stmt._independent_ctes:
            self._dispatch_independent_ctes(delete_stmt, kw)

        text += table_text

        # RETURNING before the WHERE clause, for dialects that want it
        # in that position
        if (
            self.implicit_returning or delete_stmt._returning
        ) and self.returning_precedes_values:
            text += " " + self.returning_clause(
                delete_stmt,
                self.implicit_returning or delete_stmt._returning,
                populate_result_map=toplevel,
            )

        if extra_froms:
            # dialect-specific multi-table DELETE syntax
            extra_from_text = self.delete_extra_from_clause(
                delete_stmt,
                delete_stmt.table,
                extra_froms,
                dialect_hints,
                from_linter=from_linter,
                **kw,
            )
            if extra_from_text:
                text += " " + extra_from_text

        if delete_stmt._where_criteria:
            t = self._generate_delimited_and_list(
                delete_stmt._where_criteria, from_linter=from_linter, **kw
            )
            if t:
                text += " WHERE " + t

        # dialect hook; empty string for dialects without DELETE..LIMIT
        limit_clause = self.delete_limit_clause(delete_stmt)
        if limit_clause:
            text += " " + limit_clause

        # RETURNING at the end, the more common position
        if (
            self.implicit_returning or delete_stmt._returning
        ) and not self.returning_precedes_values:
            text += " " + self.returning_clause(
                delete_stmt,
                self.implicit_returning or delete_stmt._returning,
                populate_result_map=toplevel,
            )

        if self.ctes:
            # nesting_level of None indicates a toplevel render
            nesting_level = len(self.stack) if not toplevel else None
            text = self._render_cte_clause(nesting_level=nesting_level) + text

        if warn_linting:
            assert from_linter is not None
            from_linter.warn(stmt_type="DELETE")

        self.stack.pop(-1)

        return text
6497
6498 def visit_savepoint(self, savepoint_stmt, **kw):
6499 return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
6500
6501 def visit_rollback_to_savepoint(self, savepoint_stmt, **kw):
6502 return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(
6503 savepoint_stmt
6504 )
6505
6506 def visit_release_savepoint(self, savepoint_stmt, **kw):
6507 return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(
6508 savepoint_stmt
6509 )
6510
6511
class StrSQLCompiler(SQLCompiler):
    """A :class:`.SQLCompiler` subclass which allows a small selection
    of non-standard SQL features to render into a string value.

    The :class:`.StrSQLCompiler` is invoked whenever a Core expression
    element is directly stringified without calling upon the
    :meth:`_expression.ClauseElement.compile` method.
    It can render a limited set
    of non-standard SQL constructs to assist in basic stringification,
    however for more substantial custom or dialect-specific SQL constructs,
    it will be necessary to make use of
    :meth:`_expression.ClauseElement.compile`
    directly.

    .. seealso::

        :ref:`faq_sql_expression_string`

    """

    def _fallback_column_name(self, column):
        # placeholder for columns that have no determinable name,
        # rather than raising as the base compiler does
        return "<name unknown>"

    @util.preload_module("sqlalchemy.engine.url")
    def visit_unsupported_compilation(self, element, err, **kw):
        # if the element names a specific dialect for stringification,
        # try that dialect's real compiler before falling back
        if element.stringify_dialect != "default":
            url = util.preloaded.engine_url
            dialect = url.URL.create(element.stringify_dialect).get_dialect()()

            compiler = dialect.statement_compiler(
                dialect, None, _supporting_against=self
            )
            # only delegate if we got a real dialect compiler; a
            # StrSQLCompiler here would just recurse into this path
            if not isinstance(compiler, StrSQLCompiler):
                return compiler.process(element, **kw)

        # NOTE(review): **kw is not forwarded to super() here -- confirm
        # this is intentional
        return super().visit_unsupported_compilation(element, err)

    def visit_getitem_binary(self, binary, operator, **kw):
        # generic bracketed-index notation: left[right]
        return "%s[%s]" % (
            self.process(binary.left, **kw),
            self.process(binary.right, **kw),
        )

    def visit_json_getitem_op_binary(self, binary, operator, **kw):
        # JSON index access renders the same as generic getitem
        return self.visit_getitem_binary(binary, operator, **kw)

    def visit_json_path_getitem_op_binary(self, binary, operator, **kw):
        # JSON path access renders the same as generic getitem
        return self.visit_getitem_binary(binary, operator, **kw)

    def visit_sequence(self, sequence, **kw):
        # sequences can't execute here; render a descriptive token
        return (
            f"<next sequence value: {self.preparer.format_sequence(sequence)}>"
        )

    def returning_clause(
        self,
        stmt: UpdateBase,
        returning_cols: Sequence[_ColumnsClauseElement],
        *,
        populate_result_map: bool,
        **kw: Any,
    ) -> str:
        # render a generic RETURNING list regardless of dialect support
        columns = [
            self._label_select_column(None, c, True, False, {})
            for c in base._select_iterables(returning_cols)
        ]
        return "RETURNING " + ", ".join(columns)

    def update_from_clause(
        self, update_stmt, from_table, extra_froms, from_hints, **kw
    ):
        # generic UPDATE..FROM rendering for stringification purposes
        kw["asfrom"] = True
        return "FROM " + ", ".join(
            t._compiler_dispatch(self, fromhints=from_hints, **kw)
            for t in extra_froms
        )

    def delete_extra_from_clause(
        self, delete_stmt, from_table, extra_froms, from_hints, **kw
    ):
        # generic multi-table DELETE rendering: comma-separated extra
        # tables appended after the target table
        kw["asfrom"] = True
        return ", " + ", ".join(
            t._compiler_dispatch(self, fromhints=from_hints, **kw)
            for t in extra_froms
        )

    def visit_empty_set_expr(self, element_types, **kw):
        # a SELECT that is guaranteed to return no rows
        return "SELECT 1 WHERE 1!=1"

    def get_from_hint_text(self, table, text):
        # hints render bracketed, as there is no dialect-specific syntax
        return "[%s]" % text

    def visit_regexp_match_op_binary(self, binary, operator, **kw):
        return self._generate_generic_binary(binary, " <regexp> ", **kw)

    def visit_not_regexp_match_op_binary(self, binary, operator, **kw):
        return self._generate_generic_binary(binary, " <not regexp> ", **kw)

    def visit_regexp_replace_op_binary(self, binary, operator, **kw):
        return "<regexp replace>(%s, %s)" % (
            binary.left._compiler_dispatch(self, **kw),
            binary.right._compiler_dispatch(self, **kw),
        )

    def visit_try_cast(self, cast, **kwargs):
        # TRY_CAST is dialect-specific; render the common SQL Server form
        return "TRY_CAST(%s AS %s)" % (
            cast.clause._compiler_dispatch(self, **kwargs),
            cast.typeclause._compiler_dispatch(self, **kwargs),
        )
6621
6622
class DDLCompiler(Compiled):
    """Renders DDL strings (CREATE / DROP / ALTER / COMMENT) from schema
    constructs; dialects subclass this to customize DDL syntax."""

    is_ddl = True

    if TYPE_CHECKING:

        def __init__(
            self,
            dialect: Dialect,
            statement: ExecutableDDLElement,
            schema_translate_map: Optional[SchemaTranslateMapType] = ...,
            render_schema_translate: bool = ...,
            compile_kwargs: Mapping[str, Any] = ...,
        ): ...

    @util.ro_memoized_property
    def sql_compiler(self) -> SQLCompiler:
        # statement compiler used for embedded SQL expressions such as
        # server defaults, CHECK constraints, and computed columns
        return self.dialect.statement_compiler(
            self.dialect, None, schema_translate_map=self.schema_translate_map
        )

    @util.memoized_property
    def type_compiler(self):
        return self.dialect.type_compiler_instance

    def construct_params(
        self,
        params: Optional[_CoreSingleExecuteParams] = None,
        extracted_parameters: Optional[Sequence[BindParameter[Any]]] = None,
        escape_names: bool = True,
    ) -> Optional[_MutableCoreSingleExecuteParams]:
        # DDL statements carry no bound parameters
        return None

    def visit_ddl(self, ddl, **kwargs):
        """Render a literal :class:`.DDL` string, interpolating table and
        schema names into the statement's ``%(...)s`` placeholders."""
        # table events can substitute table and schema name
        context = ddl.context
        if isinstance(ddl.target, schema.Table):
            context = context.copy()

            preparer = self.preparer
            path = preparer.format_table_seq(ddl.target)
            if len(path) == 1:
                table, sch = path[0], ""
            else:
                table, sch = path[-1], path[0]

            context.setdefault("table", table)
            context.setdefault("schema", sch)
            context.setdefault("fullname", preparer.format_table(ddl.target))

        return self.sql_compiler.post_process_text(ddl.statement % context)

    def visit_create_schema(self, create, **kw):
        text = "CREATE SCHEMA "
        if create.if_not_exists:
            text += "IF NOT EXISTS "
        return text + self.preparer.format_schema(create.element)

    def visit_drop_schema(self, drop, **kw):
        text = "DROP SCHEMA "
        if drop.if_exists:
            text += "IF EXISTS "
        text += self.preparer.format_schema(drop.element)
        if drop.cascade:
            text += " CASCADE"
        return text

    def visit_create_table(self, create, **kw):
        """Render a CREATE TABLE statement, including column
        specifications and table-level constraints."""
        table = create.element
        preparer = self.preparer

        text = "\nCREATE "
        if table._prefixes:
            text += " ".join(table._prefixes) + " "

        text += "TABLE "
        if create.if_not_exists:
            text += "IF NOT EXISTS "

        text += preparer.format_table(table) + " "

        create_table_suffix = self.create_table_suffix(table)
        if create_table_suffix:
            text += create_table_suffix + " "

        text += "("

        separator = "\n"

        # if only one primary key, specify it along with the column
        first_pk = False
        for create_column in create.columns:
            column = create_column.element
            try:
                # first_pk flags only the first PK column encountered
                processed = self.process(
                    create_column, first_pk=column.primary_key and not first_pk
                )
                if processed is not None:
                    text += separator
                    separator = ", \n"
                    text += "\t" + processed
                if column.primary_key:
                    first_pk = True
            except exc.CompileError as ce:
                # re-raise with table/column context added to the message
                raise exc.CompileError(
                    "(in table '%s', column '%s'): %s"
                    % (table.description, column.name, ce.args[0])
                ) from ce

        const = self.create_table_constraints(
            table,
            _include_foreign_key_constraints=create.include_foreign_key_constraints,  # noqa
        )
        if const:
            text += separator + "\t" + const

        text += "\n)%s\n\n" % self.post_create_table(table)
        return text

    def visit_create_column(self, create, first_pk=False, **kw):
        column = create.element

        # "system" columns are never rendered in CREATE TABLE
        if column.system:
            return None

        text = self.get_column_specification(column, first_pk=first_pk)
        # column-level constraints render inline after the column spec
        const = " ".join(
            self.process(constraint) for constraint in column.constraints
        )
        if const:
            text += " " + const

        return text

    def create_table_constraints(
        self, table, _include_foreign_key_constraints=None, **kw
    ):
        """Render the table-level constraint clauses for CREATE TABLE."""
        # On some DB order is significant: visit PK first, then the
        # other constraints (engine.ReflectionTest.testbasic failed on FB2)
        constraints = []
        if table.primary_key:
            constraints.append(table.primary_key)

        # optionally omit FK constraints not in the provided include list
        all_fkcs = table.foreign_key_constraints
        if _include_foreign_key_constraints is not None:
            omit_fkcs = all_fkcs.difference(_include_foreign_key_constraints)
        else:
            omit_fkcs = set()

        constraints.extend(
            [
                c
                for c in table._sorted_constraints
                if c is not table.primary_key and c not in omit_fkcs
            ]
        )

        # skip constraints that opt out of inline creation, and those the
        # dialect will emit separately via ALTER (use_alter)
        return ", \n\t".join(
            p
            for p in (
                self.process(constraint)
                for constraint in constraints
                if (constraint._should_create_for_compiler(self))
                and (
                    not self.dialect.supports_alter
                    or not getattr(constraint, "use_alter", False)
                )
            )
            if p is not None
        )

    def visit_drop_table(self, drop, **kw):
        text = "\nDROP TABLE "
        if drop.if_exists:
            text += "IF EXISTS "
        return text + self.preparer.format_table(drop.element)

    def visit_drop_view(self, drop, **kw):
        return "\nDROP VIEW " + self.preparer.format_table(drop.element)

    def _verify_index_table(self, index: Index) -> None:
        # an index must be bound to a table before DDL can be generated
        if index.table is None:
            raise exc.CompileError(
                "Index '%s' is not associated with any table." % index.name
            )

    def visit_create_index(
        self, create, include_schema=False, include_table_schema=True, **kw
    ):
        """Render a CREATE INDEX statement.

        :param include_schema: qualify the index name with its schema.
        :param include_table_schema: qualify the table name with its
         schema.
        """
        index = create.element
        self._verify_index_table(index)
        preparer = self.preparer
        text = "CREATE "
        if index.unique:
            text += "UNIQUE "
        if index.name is None:
            raise exc.CompileError(
                "CREATE INDEX requires that the index have a name"
            )

        text += "INDEX "
        if create.if_not_exists:
            text += "IF NOT EXISTS "

        text += "%s ON %s (%s)" % (
            self._prepared_index_name(index, include_schema=include_schema),
            preparer.format_table(
                index.table, use_schema=include_table_schema
            ),
            # index expressions render without table qualification and
            # with literal values inlined
            ", ".join(
                self.sql_compiler.process(
                    expr, include_table=False, literal_binds=True
                )
                for expr in index.expressions
            ),
        )
        return text

    def visit_drop_index(self, drop, **kw):
        index = drop.element

        if index.name is None:
            raise exc.CompileError(
                "DROP INDEX requires that the index have a name"
            )
        text = "\nDROP INDEX "
        if drop.if_exists:
            text += "IF EXISTS "

        return text + self._prepared_index_name(index, include_schema=True)

    def _prepared_index_name(
        self, index: Index, include_schema: bool = False
    ) -> str:
        """Return the quoted index name, optionally schema-qualified."""
        if index.table is not None:
            effective_schema = self.preparer.schema_for_object(index.table)
        else:
            effective_schema = None
        if include_schema and effective_schema:
            schema_name = self.preparer.quote_schema(effective_schema)
        else:
            schema_name = None

        index_name = self.preparer.format_index(index)

        if schema_name:
            index_name = schema_name + "." + index_name
        return index_name

    def visit_add_constraint(self, create, **kw):
        return "ALTER TABLE %s ADD %s" % (
            self.preparer.format_table(create.element.table),
            self.process(create.element),
        )

    def visit_set_table_comment(self, create, **kw):
        return "COMMENT ON TABLE %s IS %s" % (
            self.preparer.format_table(create.element),
            self.sql_compiler.render_literal_value(
                create.element.comment, sqltypes.String()
            ),
        )

    def visit_drop_table_comment(self, drop, **kw):
        return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
            drop.element
        )

    def visit_set_column_comment(self, create, **kw):
        return "COMMENT ON COLUMN %s IS %s" % (
            self.preparer.format_column(
                create.element, use_table=True, use_schema=True
            ),
            self.sql_compiler.render_literal_value(
                create.element.comment, sqltypes.String()
            ),
        )

    def visit_drop_column_comment(self, drop, **kw):
        return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column(
            drop.element, use_table=True
        )

    def visit_set_constraint_comment(self, create, **kw):
        # no generic syntax exists; dialects that support constraint
        # comments override this
        raise exc.UnsupportedCompilationError(self, type(create))

    def visit_drop_constraint_comment(self, drop, **kw):
        raise exc.UnsupportedCompilationError(self, type(drop))

    def get_identity_options(self, identity_options):
        """Render the option keywords shared by sequences and identity
        columns (INCREMENT BY, START WITH, MIN/MAXVALUE, CACHE, CYCLE)."""
        text = []
        if identity_options.increment is not None:
            text.append("INCREMENT BY %d" % identity_options.increment)
        if identity_options.start is not None:
            text.append("START WITH %d" % identity_options.start)
        if identity_options.minvalue is not None:
            text.append("MINVALUE %d" % identity_options.minvalue)
        if identity_options.maxvalue is not None:
            text.append("MAXVALUE %d" % identity_options.maxvalue)
        if identity_options.nominvalue is not None:
            text.append("NO MINVALUE")
        if identity_options.nomaxvalue is not None:
            text.append("NO MAXVALUE")
        if identity_options.cache is not None:
            text.append("CACHE %d" % identity_options.cache)
        if identity_options.cycle is not None:
            text.append("CYCLE" if identity_options.cycle else "NO CYCLE")
        return " ".join(text)

    def visit_create_sequence(self, create, prefix=None, **kw):
        text = "CREATE SEQUENCE "
        if create.if_not_exists:
            text += "IF NOT EXISTS "
        text += self.preparer.format_sequence(create.element)

        if prefix:
            text += prefix
        # assumes the Sequence element carries the same option attributes
        # that get_identity_options() consumes
        options = self.get_identity_options(create.element)
        if options:
            text += " " + options
        return text

    def visit_drop_sequence(self, drop, **kw):
        text = "DROP SEQUENCE "
        if drop.if_exists:
            text += "IF EXISTS "
        return text + self.preparer.format_sequence(drop.element)

    def visit_drop_constraint(self, drop, **kw):
        constraint = drop.element
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
        else:
            formatted_name = None

        # an anonymous constraint cannot be dropped by name
        if formatted_name is None:
            raise exc.CompileError(
                "Can't emit DROP CONSTRAINT for constraint %r; "
                "it has no name" % drop.element
            )
        return "ALTER TABLE %s DROP CONSTRAINT %s%s%s" % (
            self.preparer.format_table(drop.element.table),
            "IF EXISTS " if drop.if_exists else "",
            formatted_name,
            " CASCADE" if drop.cascade else "",
        )

    def get_column_specification(self, column, **kwargs):
        """Render a single column clause for CREATE TABLE: name, type,
        DEFAULT, computed/identity clauses, and NOT NULL."""
        colspec = (
            self.preparer.format_column(column)
            + " "
            + self.dialect.type_compiler_instance.process(
                column.type, type_expression=column
            )
        )
        default = self.get_column_default_string(column)
        if default is not None:
            colspec += " DEFAULT " + default

        if column.computed is not None:
            colspec += " " + self.process(column.computed)

        if (
            column.identity is not None
            and self.dialect.supports_identity_columns
        ):
            colspec += " " + self.process(column.identity)

        # NOT NULL is omitted when an identity clause was rendered above
        if not column.nullable and (
            not column.identity or not self.dialect.supports_identity_columns
        ):
            colspec += " NOT NULL"
        return colspec

    def create_table_suffix(self, table):
        # hook: text rendered after the table name, before the columns
        return ""

    def post_create_table(self, table):
        # hook: text rendered after the closing parenthesis
        return ""

    def get_column_default_string(self, column: Column[Any]) -> Optional[str]:
        """Return the rendered DEFAULT text for a column, or None if it
        has no server-side default."""
        if isinstance(column.server_default, schema.DefaultClause):
            return self.render_default_string(column.server_default.arg)
        else:
            return None

    def render_default_string(self, default: Union[Visitable, str]) -> str:
        # plain strings render as SQL string literals; SQL expressions
        # render with literal values inlined
        if isinstance(default, str):
            return self.sql_compiler.render_literal_value(
                default, sqltypes.STRINGTYPE
            )
        else:
            return self.sql_compiler.process(default, literal_binds=True)

    def visit_table_or_column_check_constraint(self, constraint, **kw):
        # dispatch on whether the CHECK was declared at the column level
        if constraint.is_column_level:
            return self.visit_column_check_constraint(constraint)
        else:
            return self.visit_check_constraint(constraint)

    def visit_check_constraint(self, constraint, **kw):
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "CHECK (%s)" % self.sql_compiler.process(
            constraint.sqltext, include_table=False, literal_binds=True
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_column_check_constraint(self, constraint, **kw):
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "CHECK (%s)" % self.sql_compiler.process(
            constraint.sqltext, include_table=False, literal_binds=True
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_primary_key_constraint(
        self, constraint: PrimaryKeyConstraint, **kw: Any
    ) -> str:
        # an empty PK renders nothing
        if len(constraint) == 0:
            return ""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "PRIMARY KEY "
        # implicitly-generated PKs put autoincrement columns first
        text += "(%s)" % ", ".join(
            self.preparer.quote(c.name)
            for c in (
                constraint.columns_autoinc_first
                if constraint._implicit_generated
                else constraint.columns
            )
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def visit_foreign_key_constraint(self, constraint, **kw):
        preparer = self.preparer
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        remote_table = list(constraint.elements)[0].column.table
        text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
            ", ".join(
                preparer.quote(f.parent.name) for f in constraint.elements
            ),
            self.define_constraint_remote_table(
                constraint, remote_table, preparer
            ),
            ", ".join(
                preparer.quote(f.column.name) for f in constraint.elements
            ),
        )
        text += self.define_constraint_match(constraint)
        text += self.define_constraint_cascades(constraint)
        text += self.define_constraint_deferrability(constraint)
        return text

    def define_constraint_remote_table(self, constraint, table, preparer):
        """Format the remote table clause of a CREATE CONSTRAINT clause."""

        return preparer.format_table(table)

    def visit_unique_constraint(
        self, constraint: UniqueConstraint, **kw: Any
    ) -> str:
        # an empty constraint renders nothing
        if len(constraint) == 0:
            return ""
        text = ""
        if constraint.name is not None:
            formatted_name = self.preparer.format_constraint(constraint)
            if formatted_name is not None:
                text += "CONSTRAINT %s " % formatted_name
        text += "UNIQUE %s(%s)" % (
            self.define_unique_constraint_distinct(constraint, **kw),
            ", ".join(self.preparer.quote(c.name) for c in constraint),
        )
        text += self.define_constraint_deferrability(constraint)
        return text

    def define_unique_constraint_distinct(
        self, constraint: UniqueConstraint, **kw: Any
    ) -> str:
        # hook: dialects may render e.g. NULLS NOT DISTINCT here
        return ""

    def define_constraint_cascades(
        self, constraint: ForeignKeyConstraint
    ) -> str:
        """Render ON DELETE / ON UPDATE clauses for a foreign key."""
        text = ""
        if constraint.ondelete is not None:
            text += self.define_constraint_ondelete_cascade(constraint)

        if constraint.onupdate is not None:
            text += self.define_constraint_onupdate_cascade(constraint)
        return text

    def define_constraint_ondelete_cascade(
        self, constraint: ForeignKeyConstraint
    ) -> str:
        return " ON DELETE %s" % self.preparer.validate_sql_phrase(
            constraint.ondelete, FK_ON_DELETE
        )

    def define_constraint_onupdate_cascade(
        self, constraint: ForeignKeyConstraint
    ) -> str:
        return " ON UPDATE %s" % self.preparer.validate_sql_phrase(
            constraint.onupdate, FK_ON_UPDATE
        )

    def define_constraint_deferrability(self, constraint: Constraint) -> str:
        """Render DEFERRABLE / INITIALLY clauses for a constraint."""
        text = ""
        if constraint.deferrable is not None:
            if constraint.deferrable:
                text += " DEFERRABLE"
            else:
                text += " NOT DEFERRABLE"
        if constraint.initially is not None:
            text += " INITIALLY %s" % self.preparer.validate_sql_phrase(
                constraint.initially, FK_INITIALLY
            )
        return text

    def define_constraint_match(self, constraint):
        text = ""
        if constraint.match is not None:
            text += " MATCH %s" % constraint.match
        return text

    def visit_computed_column(self, generated, **kw):
        """Render a GENERATED ALWAYS AS clause for a computed column."""
        text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
            generated.sqltext, include_table=False, literal_binds=True
        )
        # persisted=None renders neither keyword, leaving the choice to
        # the database
        if generated.persisted is True:
            text += " STORED"
        elif generated.persisted is False:
            text += " VIRTUAL"
        return text

    def visit_identity_column(self, identity, **kw):
        """Render a GENERATED ... AS IDENTITY clause."""
        text = "GENERATED %s AS IDENTITY" % (
            "ALWAYS" if identity.always else "BY DEFAULT",
        )
        options = self.get_identity_options(identity)
        if options:
            text += " (%s)" % options
        return text
7181
7182
7183class GenericTypeCompiler(TypeCompiler):
7184 def visit_FLOAT(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
7185 return "FLOAT"
7186
7187 def visit_DOUBLE(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
7188 return "DOUBLE"
7189
7190 def visit_DOUBLE_PRECISION(
7191 self, type_: sqltypes.DOUBLE_PRECISION[Any], **kw: Any
7192 ) -> str:
7193 return "DOUBLE PRECISION"
7194
7195 def visit_REAL(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
7196 return "REAL"
7197
7198 def visit_NUMERIC(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
7199 if type_.precision is None:
7200 return "NUMERIC"
7201 elif type_.scale is None:
7202 return "NUMERIC(%(precision)s)" % {"precision": type_.precision}
7203 else:
7204 return "NUMERIC(%(precision)s, %(scale)s)" % {
7205 "precision": type_.precision,
7206 "scale": type_.scale,
7207 }
7208
7209 def visit_DECIMAL(self, type_: sqltypes.DECIMAL[Any], **kw: Any) -> str:
7210 if type_.precision is None:
7211 return "DECIMAL"
7212 elif type_.scale is None:
7213 return "DECIMAL(%(precision)s)" % {"precision": type_.precision}
7214 else:
7215 return "DECIMAL(%(precision)s, %(scale)s)" % {
7216 "precision": type_.precision,
7217 "scale": type_.scale,
7218 }
7219
7220 def visit_INTEGER(self, type_: sqltypes.Integer, **kw: Any) -> str:
7221 return "INTEGER"
7222
7223 def visit_SMALLINT(self, type_: sqltypes.SmallInteger, **kw: Any) -> str:
7224 return "SMALLINT"
7225
7226 def visit_BIGINT(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
7227 return "BIGINT"
7228
7229 def visit_TIMESTAMP(self, type_: sqltypes.TIMESTAMP, **kw: Any) -> str:
7230 return "TIMESTAMP"
7231
7232 def visit_DATETIME(self, type_: sqltypes.DateTime, **kw: Any) -> str:
7233 return "DATETIME"
7234
7235 def visit_DATE(self, type_: sqltypes.Date, **kw: Any) -> str:
7236 return "DATE"
7237
7238 def visit_TIME(self, type_: sqltypes.Time, **kw: Any) -> str:
7239 return "TIME"
7240
7241 def visit_CLOB(self, type_: sqltypes.CLOB, **kw: Any) -> str:
7242 return "CLOB"
7243
7244 def visit_NCLOB(self, type_: sqltypes.Text, **kw: Any) -> str:
7245 return "NCLOB"
7246
7247 def _render_string_type(
7248 self, name: str, length: Optional[int], collation: Optional[str]
7249 ) -> str:
7250 text = name
7251 if length:
7252 text += f"({length})"
7253 if collation:
7254 text += f' COLLATE "{collation}"'
7255 return text
7256
7257 def visit_CHAR(self, type_: sqltypes.CHAR, **kw: Any) -> str:
7258 return self._render_string_type("CHAR", type_.length, type_.collation)
7259
7260 def visit_NCHAR(self, type_: sqltypes.NCHAR, **kw: Any) -> str:
7261 return self._render_string_type("NCHAR", type_.length, type_.collation)
7262
7263 def visit_VARCHAR(self, type_: sqltypes.String, **kw: Any) -> str:
7264 return self._render_string_type(
7265 "VARCHAR", type_.length, type_.collation
7266 )
7267
7268 def visit_NVARCHAR(self, type_: sqltypes.NVARCHAR, **kw: Any) -> str:
7269 return self._render_string_type(
7270 "NVARCHAR", type_.length, type_.collation
7271 )
7272
7273 def visit_TEXT(self, type_: sqltypes.Text, **kw: Any) -> str:
7274 return self._render_string_type("TEXT", type_.length, type_.collation)
7275
7276 def visit_UUID(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
7277 return "UUID"
7278
7279 def visit_BLOB(self, type_: sqltypes.LargeBinary, **kw: Any) -> str:
7280 return "BLOB"
7281
7282 def visit_BINARY(self, type_: sqltypes.BINARY, **kw: Any) -> str:
7283 return "BINARY" + (type_.length and "(%d)" % type_.length or "")
7284
7285 def visit_VARBINARY(self, type_: sqltypes.VARBINARY, **kw: Any) -> str:
7286 return "VARBINARY" + (type_.length and "(%d)" % type_.length or "")
7287
7288 def visit_BOOLEAN(self, type_: sqltypes.Boolean, **kw: Any) -> str:
7289 return "BOOLEAN"
7290
7291 def visit_uuid(self, type_: sqltypes.Uuid[Any], **kw: Any) -> str:
7292 if not type_.native_uuid or not self.dialect.supports_native_uuid:
7293 return self._render_string_type("CHAR", length=32, collation=None)
7294 else:
7295 return self.visit_UUID(type_, **kw)
7296
7297 def visit_large_binary(
7298 self, type_: sqltypes.LargeBinary, **kw: Any
7299 ) -> str:
7300 return self.visit_BLOB(type_, **kw)
7301
7302 def visit_boolean(self, type_: sqltypes.Boolean, **kw: Any) -> str:
7303 return self.visit_BOOLEAN(type_, **kw)
7304
7305 def visit_time(self, type_: sqltypes.Time, **kw: Any) -> str:
7306 return self.visit_TIME(type_, **kw)
7307
7308 def visit_datetime(self, type_: sqltypes.DateTime, **kw: Any) -> str:
7309 return self.visit_DATETIME(type_, **kw)
7310
7311 def visit_date(self, type_: sqltypes.Date, **kw: Any) -> str:
7312 return self.visit_DATE(type_, **kw)
7313
7314 def visit_big_integer(self, type_: sqltypes.BigInteger, **kw: Any) -> str:
7315 return self.visit_BIGINT(type_, **kw)
7316
7317 def visit_small_integer(
7318 self, type_: sqltypes.SmallInteger, **kw: Any
7319 ) -> str:
7320 return self.visit_SMALLINT(type_, **kw)
7321
7322 def visit_integer(self, type_: sqltypes.Integer, **kw: Any) -> str:
7323 return self.visit_INTEGER(type_, **kw)
7324
    def visit_real(self, type_: sqltypes.REAL[Any], **kw: Any) -> str:
        """Render the REAL type via the REAL rendering."""
        return self.visit_REAL(type_, **kw)
7327
    def visit_float(self, type_: sqltypes.Float[Any], **kw: Any) -> str:
        """Render the generic Float type via the FLOAT rendering."""
        return self.visit_FLOAT(type_, **kw)
7330
    def visit_double(self, type_: sqltypes.Double[Any], **kw: Any) -> str:
        """Render the generic Double type via the DOUBLE rendering."""
        return self.visit_DOUBLE(type_, **kw)
7333
    def visit_numeric(self, type_: sqltypes.Numeric[Any], **kw: Any) -> str:
        """Render the generic Numeric type via the NUMERIC rendering."""
        return self.visit_NUMERIC(type_, **kw)
7336
    def visit_string(self, type_: sqltypes.String, **kw: Any) -> str:
        """Render the generic String type via the VARCHAR rendering."""
        return self.visit_VARCHAR(type_, **kw)
7339
    def visit_unicode(self, type_: sqltypes.Unicode, **kw: Any) -> str:
        """Render the generic Unicode type via the VARCHAR rendering."""
        return self.visit_VARCHAR(type_, **kw)
7342
    def visit_text(self, type_: sqltypes.Text, **kw: Any) -> str:
        """Render the generic Text type via the TEXT rendering."""
        return self.visit_TEXT(type_, **kw)
7345
    def visit_unicode_text(
        self, type_: sqltypes.UnicodeText, **kw: Any
    ) -> str:
        """Render the generic UnicodeText type via the TEXT rendering."""
        return self.visit_TEXT(type_, **kw)
7350
    def visit_enum(self, type_: sqltypes.Enum, **kw: Any) -> str:
        """Render the generic Enum type as VARCHAR (the non-native
        fallback used when a dialect has no native enumerated type)."""
        return self.visit_VARCHAR(type_, **kw)
7353
    def visit_null(self, type_, **kw):
        """Raise for NullType: a Column was left without a usable type,
        so no DDL can be generated for it."""
        raise exc.CompileError(
            "Can't generate DDL for %r; "
            "did you forget to specify a "
            "type on this Column?" % type_
        )
7360
    def visit_type_decorator(
        self, type_: TypeDecorator[Any], **kw: Any
    ) -> str:
        """Render a TypeDecorator by compiling the dialect-specific
        implementation type it resolves to."""
        return self.process(type_.type_engine(self.dialect), **kw)
7365
    def visit_user_defined(
        self, type_: UserDefinedType[Any], **kw: Any
    ) -> str:
        """Render a UserDefinedType via its ``get_col_spec()`` hook."""
        return type_.get_col_spec(**kw)
7370
7371
class StrSQLTypeCompiler(GenericTypeCompiler):
    """Type compiler used for dialect-neutral string rendering.

    Unknown or dialect-specific types are rendered as a best-effort
    placeholder instead of raising, so any expression can be stringified.
    """

    def process(self, type_, **kw):
        # EAFP: types lacking a visit dispatch get the fallback rendering
        try:
            dispatch = type_._compiler_dispatch
        except AttributeError:
            return self._visit_unknown(type_, **kw)
        return dispatch(self, **kw)

    def __getattr__(self, key):
        # any visit_XYZ method not otherwise defined falls back to the
        # unknown-type rendering
        if not key.startswith("visit_"):
            raise AttributeError(key)
        return self._visit_unknown

    def _visit_unknown(self, type_, **kw):
        # an all-uppercase class name is assumed to be the SQL keyword
        # itself; otherwise show the repr of the type object
        cls_name = type_.__class__.__name__
        if cls_name == cls_name.upper():
            return cls_name
        return repr(type_)

    def visit_null(self, type_, **kw):
        return "NULL"

    def visit_user_defined(self, type_, **kw):
        try:
            get_col_spec = type_.get_col_spec
        except AttributeError:
            return repr(type_)
        return get_col_spec(**kw)
7403
7404
class _SchemaForObjectCallable(Protocol):
    # callable taking a schema item and returning its effective schema name
    def __call__(self, __obj: Any) -> str: ...
7407
7408
class _BindNameForColProtocol(Protocol):
    # callable taking a column clause and returning its bind-parameter name
    def __call__(self, col: ColumnClause[Any]) -> str: ...
7411
7412
7413class IdentifierPreparer:
7414 """Handle quoting and case-folding of identifiers based on options."""
7415
7416 reserved_words = RESERVED_WORDS
7417
7418 legal_characters = LEGAL_CHARACTERS
7419
7420 illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
7421
7422 initial_quote: str
7423
7424 final_quote: str
7425
7426 _strings: MutableMapping[str, str]
7427
7428 schema_for_object: _SchemaForObjectCallable = operator.attrgetter("schema")
7429 """Return the .schema attribute for an object.
7430
7431 For the default IdentifierPreparer, the schema for an object is always
7432 the value of the ".schema" attribute. if the preparer is replaced
7433 with one that has a non-empty schema_translate_map, the value of the
7434 ".schema" attribute is rendered a symbol that will be converted to a
7435 real schema name from the mapping post-compile.
7436
7437 """
7438
7439 _includes_none_schema_translate: bool = False
7440
    def __init__(
        self,
        dialect: Dialect,
        initial_quote: str = '"',
        final_quote: Optional[str] = None,
        escape_quote: str = '"',
        quote_case_sensitive_collations: bool = True,
        omit_schema: bool = False,
    ):
        """Construct a new ``IdentifierPreparer`` object.

        initial_quote
          Character that begins a delimited identifier.

        final_quote
          Character that ends a delimited identifier. Defaults to
          `initial_quote`.

        escape_quote
          Character that, doubled, represents an embedded quote inside a
          delimited identifier.

        quote_case_sensitive_collations
          When True, collation names passed to :meth:`.format_collation`
          are run through :meth:`.quote`.

        omit_schema
          Prevent prepending schema name. Useful for databases that do
          not support schemae.
        """

        self.dialect = dialect
        self.initial_quote = initial_quote
        self.final_quote = final_quote or self.initial_quote
        self.escape_quote = escape_quote
        self.escape_to_quote = self.escape_quote * 2
        self.omit_schema = omit_schema
        self.quote_case_sensitive_collations = quote_case_sensitive_collations
        self._strings = {}
        # these paramstyles use %-interpolation at execution time, so
        # literal percent signs inside identifiers must be doubled
        self._double_percents = self.dialect.paramstyle in (
            "format",
            "pyformat",
        )
7476
    def _with_schema_translate(self, schema_translate_map):
        """Return a copy of this preparer whose ``schema_for_object``
        renders symbolic placeholders for post-compile schema translation.

        The placeholders have the form ``__[SCHEMA_<name>]`` (or
        ``__[SCHEMA__none]`` for a ``None`` schema) and are substituted
        with real schema names after compilation.
        """
        # shallow-copy this preparer without re-running __init__
        prep = self.__class__.__new__(self.__class__)
        prep.__dict__.update(self.__dict__)

        includes_none = None in schema_translate_map

        def symbol_getter(obj):
            name = obj.schema
            if obj._use_schema_map and (name is not None or includes_none):
                # square brackets would corrupt the placeholder syntax
                # that the post-compile regex relies on
                if name is not None and ("[" in name or "]" in name):
                    raise exc.CompileError(
                        "Square bracket characters ([]) not supported "
                        "in schema translate name '%s'" % name
                    )
                # quote=False so the placeholder passes through rendering
                # verbatim and can be located by the regex afterwards
                return quoted_name(
                    "__[SCHEMA_%s]" % (name or "_none"), quote=False
                )
            else:
                return obj.schema

        prep.schema_for_object = symbol_getter
        prep._includes_none_schema_translate = includes_none
        return prep
7500
    def _render_schema_translates(
        self, statement: str, schema_translate_map: SchemaTranslateMapType
    ) -> str:
        """Replace ``__[SCHEMA_<name>]`` placeholders in a compiled
        statement with actual schema names from ``schema_translate_map``.

        Raises :class:`.InvalidRequestError` when the presence of a
        ``None`` key in the map is inconsistent with the map that was in
        effect at compile time, since the placeholders would not line up.
        """
        d = schema_translate_map
        if None in d:
            if not self._includes_none_schema_translate:
                raise exc.InvalidRequestError(
                    "schema translate map which previously did not have "
                    "`None` present as a key now has `None` present; compiled "
                    "statement may lack adequate placeholders. Please use "
                    "consistent keys in successive "
                    "schema_translate_map dictionaries."
                )

            # the placeholder for a None schema is spelled "_none"
            d["_none"] = d[None]  # type: ignore[index]

        def replace(m):
            # group 2 is the schema name captured from the placeholder
            name = m.group(2)
            if name in d:
                effective_schema = d[name]
            else:
                if name in (None, "_none"):
                    raise exc.InvalidRequestError(
                        "schema translate map which previously had `None` "
                        "present as a key now no longer has it present; don't "
                        "know how to apply schema for compiled statement. "
                        "Please use consistent keys in successive "
                        "schema_translate_map dictionaries."
                    )
                effective_schema = name

            if not effective_schema:
                # fall back to the dialect's default schema for blank/None
                effective_schema = self.dialect.default_schema_name
                if not effective_schema:
                    # TODO: no coverage here
                    raise exc.CompileError(
                        "Dialect has no default schema name; can't "
                        "use None as dynamic schema target."
                    )
            return self.quote_schema(effective_schema)

        return re.sub(r"(__\[SCHEMA_([^\]]+)\])", replace, statement)
7543
7544 def _escape_identifier(self, value: str) -> str:
7545 """Escape an identifier.
7546
7547 Subclasses should override this to provide database-dependent
7548 escaping behavior.
7549 """
7550
7551 value = value.replace(self.escape_quote, self.escape_to_quote)
7552 if self._double_percents:
7553 value = value.replace("%", "%%")
7554 return value
7555
7556 def _unescape_identifier(self, value: str) -> str:
7557 """Canonicalize an escaped identifier.
7558
7559 Subclasses should override this to provide database-dependent
7560 unescaping behavior that reverses _escape_identifier.
7561 """
7562
7563 return value.replace(self.escape_to_quote, self.escape_quote)
7564
7565 def validate_sql_phrase(self, element, reg):
7566 """keyword sequence filter.
7567
7568 a filter for elements that are intended to represent keyword sequences,
7569 such as "INITIALLY", "INITIALLY DEFERRED", etc. no special characters
7570 should be present.
7571
7572 .. versionadded:: 1.3
7573
7574 """
7575
7576 if element is not None and not reg.match(element):
7577 raise exc.CompileError(
7578 "Unexpected SQL phrase: %r (matching against %r)"
7579 % (element, reg.pattern)
7580 )
7581 return element
7582
7583 def quote_identifier(self, value: str) -> str:
7584 """Quote an identifier.
7585
7586 Subclasses should override this to provide database-dependent
7587 quoting behavior.
7588 """
7589
7590 return (
7591 self.initial_quote
7592 + self._escape_identifier(value)
7593 + self.final_quote
7594 )
7595
7596 def _requires_quotes(self, value: str) -> bool:
7597 """Return True if the given identifier requires quoting."""
7598 lc_value = value.lower()
7599 return (
7600 lc_value in self.reserved_words
7601 or value[0] in self.illegal_initial_characters
7602 or not self.legal_characters.match(str(value))
7603 or (lc_value != value)
7604 )
7605
7606 def _requires_quotes_illegal_chars(self, value):
7607 """Return True if the given identifier requires quoting, but
7608 not taking case convention into account."""
7609 return not self.legal_characters.match(str(value))
7610
    def quote_schema(self, schema: str, force: Any = None) -> str:
        """Conditionally quote a schema name.


        The name is quoted if it is a reserved word, contains quote-necessary
        characters, or is an instance of :class:`.quoted_name` which includes
        ``quote`` set to ``True``.

        Subclasses can override this to provide database-dependent
        quoting behavior for schema names.

        :param schema: string schema name
        :param force: unused

        .. deprecated:: 0.9

            The :paramref:`.IdentifierPreparer.quote_schema.force`
            parameter is deprecated and will be removed in a future
            release. This flag has no effect on the behavior of the
            :meth:`.IdentifierPreparer.quote` method; please refer to
            :class:`.quoted_name`.

        """
        if force is not None:
            # not using the util.deprecated_params() decorator in this
            # case because of the additional function call overhead on this
            # very performance-critical spot.
            util.warn_deprecated(
                "The IdentifierPreparer.quote_schema.force parameter is "
                "deprecated and will be removed in a future release. This "
                "flag has no effect on the behavior of the "
                "IdentifierPreparer.quote method; please refer to "
                "quoted_name().",
                # deprecated 0.9. warning from 1.3
                version="0.9",
            )

        # the actual quoting decision is delegated to quote()
        return self.quote(schema)
7649
7650 def quote(self, ident: str, force: Any = None) -> str:
7651 """Conditionally quote an identifier.
7652
7653 The identifier is quoted if it is a reserved word, contains
7654 quote-necessary characters, or is an instance of
7655 :class:`.quoted_name` which includes ``quote`` set to ``True``.
7656
7657 Subclasses can override this to provide database-dependent
7658 quoting behavior for identifier names.
7659
7660 :param ident: string identifier
7661 :param force: unused
7662
7663 .. deprecated:: 0.9
7664
7665 The :paramref:`.IdentifierPreparer.quote.force`
7666 parameter is deprecated and will be removed in a future
7667 release. This flag has no effect on the behavior of the
7668 :meth:`.IdentifierPreparer.quote` method; please refer to
7669 :class:`.quoted_name`.
7670
7671 """
7672 if force is not None:
7673 # not using the util.deprecated_params() decorator in this
7674 # case because of the additional function call overhead on this
7675 # very performance-critical spot.
7676 util.warn_deprecated(
7677 "The IdentifierPreparer.quote.force parameter is "
7678 "deprecated and will be removed in a future release. This "
7679 "flag has no effect on the behavior of the "
7680 "IdentifierPreparer.quote method; please refer to "
7681 "quoted_name().",
7682 # deprecated 0.9. warning from 1.3
7683 version="0.9",
7684 )
7685
7686 force = getattr(ident, "quote", None)
7687
7688 if force is None:
7689 if ident in self._strings:
7690 return self._strings[ident]
7691 else:
7692 if self._requires_quotes(ident):
7693 self._strings[ident] = self.quote_identifier(ident)
7694 else:
7695 self._strings[ident] = ident
7696 return self._strings[ident]
7697 elif force:
7698 return self.quote_identifier(ident)
7699 else:
7700 return ident
7701
7702 def format_collation(self, collation_name):
7703 if self.quote_case_sensitive_collations:
7704 return self.quote(collation_name)
7705 else:
7706 return collation_name
7707
7708 def format_sequence(
7709 self, sequence: schema.Sequence, use_schema: bool = True
7710 ) -> str:
7711 name = self.quote(sequence.name)
7712
7713 effective_schema = self.schema_for_object(sequence)
7714
7715 if (
7716 not self.omit_schema
7717 and use_schema
7718 and effective_schema is not None
7719 ):
7720 name = self.quote_schema(effective_schema) + "." + name
7721 return name
7722
7723 def format_label(
7724 self, label: Label[Any], name: Optional[str] = None
7725 ) -> str:
7726 return self.quote(name or label.name)
7727
7728 def format_alias(
7729 self, alias: Optional[AliasedReturnsRows], name: Optional[str] = None
7730 ) -> str:
7731 if name is None:
7732 assert alias is not None
7733 return self.quote(alias.name)
7734 else:
7735 return self.quote(name)
7736
7737 def format_savepoint(self, savepoint, name=None):
7738 # Running the savepoint name through quoting is unnecessary
7739 # for all known dialects. This is here to support potential
7740 # third party use cases
7741 ident = name or savepoint.ident
7742 if self._requires_quotes(ident):
7743 ident = self.quote_identifier(ident)
7744 return ident
7745
    @util.preload_module("sqlalchemy.sql.naming")
    def format_constraint(
        self, constraint: Union[Constraint, Index], _alembic_quote: bool = True
    ) -> Optional[str]:
        """Return the quoted, length-truncated name of a constraint or
        index, or ``None`` if no name can be determined.

        When the constraint carries no explicit name (``_NONE_NAME``),
        a name is derived from the active naming convention; if that
        yields nothing, ``None`` is returned.
        """
        naming = util.preloaded.sql_naming

        if constraint.name is _NONE_NAME:
            name = naming._constraint_name_for_table(
                constraint, constraint.table
            )

            if name is None:
                return None
        else:
            name = constraint.name

        assert name is not None
        # indexes may have a different maximum name length than
        # other constraints
        if constraint.__visit_name__ == "index":
            return self.truncate_and_render_index_name(
                name, _alembic_quote=_alembic_quote
            )
        else:
            return self.truncate_and_render_constraint_name(
                name, _alembic_quote=_alembic_quote
            )
7771
7772 def truncate_and_render_index_name(
7773 self, name: str, _alembic_quote: bool = True
7774 ) -> str:
7775 # calculate these at format time so that ad-hoc changes
7776 # to dialect.max_identifier_length etc. can be reflected
7777 # as IdentifierPreparer is long lived
7778 max_ = (
7779 self.dialect.max_index_name_length
7780 or self.dialect.max_identifier_length
7781 )
7782 return self._truncate_and_render_maxlen_name(
7783 name, max_, _alembic_quote
7784 )
7785
7786 def truncate_and_render_constraint_name(
7787 self, name: str, _alembic_quote: bool = True
7788 ) -> str:
7789 # calculate these at format time so that ad-hoc changes
7790 # to dialect.max_identifier_length etc. can be reflected
7791 # as IdentifierPreparer is long lived
7792 max_ = (
7793 self.dialect.max_constraint_name_length
7794 or self.dialect.max_identifier_length
7795 )
7796 return self._truncate_and_render_maxlen_name(
7797 name, max_, _alembic_quote
7798 )
7799
7800 def _truncate_and_render_maxlen_name(
7801 self, name: str, max_: int, _alembic_quote: bool
7802 ) -> str:
7803 if isinstance(name, elements._truncated_label):
7804 if len(name) > max_:
7805 name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
7806 else:
7807 self.dialect.validate_identifier(name)
7808
7809 if not _alembic_quote:
7810 return name
7811 else:
7812 return self.quote(name)
7813
7814 def format_index(self, index: Index) -> str:
7815 name = self.format_constraint(index)
7816 assert name is not None
7817 return name
7818
7819 def format_table(
7820 self,
7821 table: FromClause,
7822 use_schema: bool = True,
7823 name: Optional[str] = None,
7824 ) -> str:
7825 """Prepare a quoted table and schema name."""
7826 if name is None:
7827 if TYPE_CHECKING:
7828 assert isinstance(table, NamedFromClause)
7829 name = table.name
7830
7831 result = self.quote(name)
7832
7833 effective_schema = self.schema_for_object(table)
7834
7835 if not self.omit_schema and use_schema and effective_schema:
7836 result = self.quote_schema(effective_schema) + "." + result
7837 return result
7838
    def format_schema(self, name):
        """Prepare a quoted schema name."""

        return self.quote(name)
7843
7844 def format_label_name(
7845 self,
7846 name,
7847 anon_map=None,
7848 ):
7849 """Prepare a quoted column name."""
7850
7851 if anon_map is not None and isinstance(
7852 name, elements._truncated_label
7853 ):
7854 name = name.apply_map(anon_map)
7855
7856 return self.quote(name)
7857
7858 def format_column(
7859 self,
7860 column: ColumnElement[Any],
7861 use_table: bool = False,
7862 name: Optional[str] = None,
7863 table_name: Optional[str] = None,
7864 use_schema: bool = False,
7865 anon_map: Optional[Mapping[str, Any]] = None,
7866 ) -> str:
7867 """Prepare a quoted column name."""
7868
7869 if name is None:
7870 name = column.name
7871 assert name is not None
7872
7873 if anon_map is not None and isinstance(
7874 name, elements._truncated_label
7875 ):
7876 name = name.apply_map(anon_map)
7877
7878 if not getattr(column, "is_literal", False):
7879 if use_table:
7880 return (
7881 self.format_table(
7882 column.table, use_schema=use_schema, name=table_name
7883 )
7884 + "."
7885 + self.quote(name)
7886 )
7887 else:
7888 return self.quote(name)
7889 else:
7890 # literal textual elements get stuck into ColumnClause a lot,
7891 # which shouldn't get quoted
7892
7893 if use_table:
7894 return (
7895 self.format_table(
7896 column.table, use_schema=use_schema, name=table_name
7897 )
7898 + "."
7899 + name
7900 )
7901 else:
7902 return name
7903
7904 def format_table_seq(self, table, use_schema=True):
7905 """Format table name and schema as a tuple."""
7906
7907 # Dialects with more levels in their fully qualified references
7908 # ('database', 'owner', etc.) could override this and return
7909 # a longer sequence.
7910
7911 effective_schema = self.schema_for_object(table)
7912
7913 if not self.omit_schema and use_schema and effective_schema:
7914 return (
7915 self.quote_schema(effective_schema),
7916 self.format_table(table, use_schema=False),
7917 )
7918 else:
7919 return (self.format_table(table, use_schema=False),)
7920
    @util.memoized_property
    def _r_identifiers(self):
        """Compiled regex matching dot-separated, possibly quoted
        identifier chains, e.g. ``"schema"."table"."col"``."""
        initial, final, escaped_final = (
            re.escape(s)
            for s in (
                self.initial_quote,
                self.final_quote,
                self._escape_identifier(self.final_quote),
            )
        )
        # per dot-separated segment: group 1 captures the body of a quoted
        # identifier (embedded escaped quotes allowed); group 2 captures an
        # unquoted identifier (anything up to the next dot)
        r = re.compile(
            r"(?:"
            r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s"
            r"|([^\.]+))(?=\.|$))+"
            % {"initial": initial, "final": final, "escaped": escaped_final}
        )
        return r
7938
7939 def unformat_identifiers(self, identifiers: str) -> Sequence[str]:
7940 """Unpack 'schema.table.column'-like strings into components."""
7941
7942 r = self._r_identifiers
7943 return [
7944 self._unescape_identifier(i)
7945 for i in [a or b for a, b in r.findall(identifiers)]
7946 ]