1# sql/crud.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""Functions used by compiler.py to determine the parameters rendered
9within INSERT and UPDATE statements.
10
11"""
12import operator
13
14from . import dml
15from . import elements
16from .. import exc
17from .. import util
18
19
# Sentinel used as the value of a crud BindParameter whose value must be
# supplied when the statement is executed (see the symbol's docstring).
REQUIRED = util.symbol(
    "REQUIRED",
    """
Placeholder for the value within a :class:`.BindParameter`
which is required to be present when the statement is passed
to :meth:`_engine.Connection.execute`.

This symbol is typically used when a :func:`_expression.insert`
or :func:`_expression.update` statement is compiled without parameter
values present.

""",
)

# Statement-type markers accepted by _setup_crud_params() as its
# ``local_stmt_type`` argument.
ISINSERT, ISUPDATE, ISDELETE = map(
    util.symbol, ("ISINSERT", "ISUPDATE", "ISDELETE")
)
37
38
def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
    """Set the compiler's DML-type flags for ``stmt`` and, for INSERT or
    UPDATE, return the column/value pairs built by _get_crud_params().

    ``local_stmt_type`` must be one of ISINSERT, ISUPDATE or ISDELETE.
    For ISDELETE nothing is computed and ``None`` is returned.

    The compiler's ``isinsert`` / ``isupdate`` / ``isdelete`` flags are put
    back to their previous values afterwards (even on error) whenever any
    of them was already set, ``compiler.stack`` holds more than one entry,
    or ``"visiting_cte"`` is present in ``kw``.
    """
    saved_insert, saved_update, saved_delete = (
        compiler.isinsert,
        compiler.isupdate,
        compiler.isdelete,
    )

    must_restore = (
        saved_insert
        or saved_update
        or saved_delete
        or len(compiler.stack) > 1
        or "visiting_cte" in kw
    )

    if local_stmt_type is ISDELETE:
        # isdelete is only switched on for a top-level DELETE; in the
        # "restore" situations the flags are left untouched.
        if not must_restore:
            compiler.isdelete = True
    elif local_stmt_type is ISINSERT or local_stmt_type is ISUPDATE:
        is_update = local_stmt_type is ISUPDATE
        compiler.isupdate = is_update
        compiler.isinsert = not is_update
    else:
        assert False, "ISINSERT, ISUPDATE, or ISDELETE expected"

    try:
        if local_stmt_type in (ISINSERT, ISUPDATE):
            return _get_crud_params(compiler, stmt, **kw)
        return None
    finally:
        if must_restore:
            compiler.isinsert = saved_insert
            compiler.isupdate = saved_update
            compiler.isdelete = saved_delete
70
71
def _get_crud_params(compiler, stmt, **kw):
    """create a set of tuples representing column/string pairs for use
    in an INSERT or UPDATE statement.

    Also generates the Compiled object's postfetch, prefetch, and
    returning column collections, used for default handling and ultimately
    populating the ResultProxy's prefetch_cols() and postfetch_cols()
    collections.

    :param compiler: the compiler in use.  Its ``postfetch``,
     ``insert_prefetch``, ``update_prefetch`` and ``returning`` lists are
     reset to empty here and filled in by the helpers called below.
    :param stmt: the INSERT or UPDATE construct being compiled.
    :param kw: compilation keyword arguments, passed through to the helpers.
    :return: a list of ``(column, value)`` tuples.  When neither
     ``compiler.column_keys`` nor ``stmt.parameters`` is present, every
     table column gets a required bind parameter.  For a multi-VALUES
     statement (``stmt._has_multi_parameters``) the result is instead a
     list of such lists, one per row.
    :raises CompileError: ("Unconsumed column names") when keys given in
     the statement's parameters are still left in ``parameters`` after the
     column scan and were not handled as multi-table columns.

    """

    compiler.postfetch = []
    compiler.insert_prefetch = []
    compiler.update_prefetch = []
    compiler.returning = []

    # no parameters in the statement, no parameters in the
    # compiled params - return binds for all columns
    if compiler.column_keys is None and stmt.parameters is None:
        return [
            (c, _create_bind_param(compiler, c, None, required=True))
            for c in stmt.table.columns
        ]

    # for a multi-VALUES statement the first row drives the column scan;
    # the remaining rows are rendered at the end by
    # _extend_values_for_multiparams()
    if stmt._has_multi_parameters:
        stmt_parameters = stmt.parameters[0]
    else:
        stmt_parameters = stmt.parameters

    # getters - these are normally just column.key,
    # but in the case of mysql multi-table update, the rules for
    # .key must conditionally take tablename into account
    (
        _column_as_key,
        _getattr_col_key,
        _col_bind_name,
    ) = _key_getters_for_crud_column(compiler, stmt)

    # if we have statement parameters - set defaults in the
    # compiled params
    if compiler.column_keys is None:
        parameters = {}
    else:
        parameters = dict(
            (_column_as_key(key), REQUIRED)
            for key in compiler.column_keys
            if not stmt_parameters or key not in stmt_parameters
        )

    # create a list of column assignment clauses as tuples
    values = []

    if stmt_parameters is not None:
        _get_stmt_parameters_params(
            compiler, parameters, stmt_parameters, _column_as_key, values, kw
        )

    # keys of columns already handled by _get_multitable_params(); the
    # column scans below skip parameters whose key is recorded here
    check_columns = {}

    # special logic that only occurs for multi-table UPDATE
    # statements
    if compiler.isupdate and stmt._extra_froms and stmt_parameters:
        _get_multitable_params(
            compiler,
            stmt,
            stmt_parameters,
            check_columns,
            _col_bind_name,
            _getattr_col_key,
            values,
            kw,
        )

    if compiler.isinsert and stmt.select_names:
        _scan_insert_from_select_cols(
            compiler,
            stmt,
            parameters,
            _getattr_col_key,
            _column_as_key,
            _col_bind_name,
            check_columns,
            values,
            kw,
        )
    else:
        _scan_cols(
            compiler,
            stmt,
            parameters,
            _getattr_col_key,
            _column_as_key,
            _col_bind_name,
            check_columns,
            values,
            kw,
        )

    # the scans pop every parameter they consume; anything from the
    # statement's own parameters that is still present matched no column
    if parameters and stmt_parameters:
        check = (
            set(parameters)
            .intersection(_column_as_key(k) for k in stmt_parameters)
            .difference(check_columns)
        )
        if check:
            raise exc.CompileError(
                "Unconsumed column names: %s"
                % (", ".join("%s" % c for c in check))
            )

    if stmt._has_multi_parameters:
        values = _extend_values_for_multiparams(compiler, stmt, values, kw)

    return values
186
187
def _create_bind_param(
    compiler, col, value, process=True, required=False, name=None, **kw
):
    """Build a BindParameter for ``col`` holding ``value``.

    The parameter is named ``name`` (``col.key`` when ``name`` is None),
    typed as ``col.type`` and flagged with ``_is_crud = True``.  When
    ``process`` is true the compiled form produced by the compiler is
    returned; otherwise the BindParameter object itself is returned.
    """
    bind = elements.BindParameter(
        col.key if name is None else name,
        value,
        type_=col.type,
        required=required,
    )
    bind._is_crud = True
    if not process:
        return bind
    return bind._compiler_dispatch(compiler, **kw)
200
201
202def _key_getters_for_crud_column(compiler, stmt):
203 if compiler.isupdate and stmt._extra_froms:
204 # when extra tables are present, refer to the columns
205 # in those extra tables as table-qualified, including in
206 # dictionaries and when rendering bind param names.
207 # the "main" table of the statement remains unqualified,
208 # allowing the most compatibility with a non-multi-table
209 # statement.
210 _et = set(stmt._extra_froms)
211
212 def _column_as_key(key):
213 str_key = elements._column_as_key(key)
214 if hasattr(key, "table") and key.table in _et:
215 return (key.table.name, str_key)
216 else:
217 return str_key
218
219 def _getattr_col_key(col):
220 if col.table in _et:
221 return (col.table.name, col.key)
222 else:
223 return col.key
224
225 def _col_bind_name(col):
226 if col.table in _et:
227 return "%s_%s" % (col.table.name, col.key)
228 else:
229 return col.key
230
231 else:
232 _column_as_key = elements._column_as_key
233 _getattr_col_key = _col_bind_name = operator.attrgetter("key")
234
235 return _column_as_key, _getattr_col_key, _col_bind_name
236
237
def _scan_insert_from_select_cols(
    compiler,
    stmt,
    parameters,
    _getattr_col_key,
    _column_as_key,
    _col_bind_name,
    check_columns,
    values,
    kw,
):
    """Build the column/value list for an INSERT ... FROM SELECT.

    The target columns are those named in ``stmt.select_names``.  When
    ``stmt.include_insert_from_select_defaults`` is set, every other table
    column that has a ``default`` is appended as well.  Its default
    expression is added to the end of the SELECT's column list, on a new
    SELECT stored in ``compiler._insert_from_select``.

    Appends ``(column, value)`` tuples to ``values``; columns whose value
    is given in ``parameters`` are recorded with a value of ``None`` and
    their parameter is consumed (popped).
    """
    cols = [stmt.table.c[_column_as_key(name)] for name in stmt.select_names]

    compiler._insert_from_select = stmt.select

    add_select_cols = []
    if stmt.include_insert_from_select_defaults:
        col_set = set(cols)
        for col in stmt.table.columns:
            if col not in col_set and col.default is not None:
                cols.append(col)

    for c in cols:
        col_key = _getattr_col_key(c)
        if col_key in parameters and col_key not in check_columns:
            parameters.pop(col_key)
            values.append((c, None))
        else:
            # no parameter for this column: render its default into the
            # SELECT via add_select_cols
            _append_param_insert_select_hasdefault(
                compiler, stmt, c, add_select_cols, kw
            )

    if add_select_cols:
        values.extend(add_select_cols)
        # generate a new SELECT (rather than mutating stmt.select) and
        # append the default expressions to its columns so they line up
        # with the extra INSERT columns
        compiler._insert_from_select = compiler._insert_from_select._generate()
        compiler._insert_from_select._raw_columns = tuple(
            compiler._insert_from_select._raw_columns
        ) + tuple(expr for col, expr in add_select_cols)
284
285
def _scan_cols(
    compiler,
    stmt,
    parameters,
    _getattr_col_key,
    _column_as_key,
    _col_bind_name,
    check_columns,
    values,
    kw,
):
    """Walk the columns of ``stmt.table`` and build the VALUES / SET list.

    Columns named in ``stmt._parameter_ordering`` come first, in that
    order, followed by the remaining table columns.  For each column this
    either consumes an explicit value from ``parameters`` or, for INSERT
    and UPDATE, arranges for the column's default / onupdate / server-side
    value to be rendered, prefetched, postfetched or returned.  Results go
    into ``values`` and the compiler's ``returning`` / ``postfetch`` /
    prefetch collections.
    """
    (
        need_pks,
        implicit_returning,
        implicit_return_defaults,
        postfetch_lastrowid,
    ) = _get_returning_modifiers(compiler, stmt)

    if stmt._parameter_ordering:
        parameter_ordering = [
            _column_as_key(key) for key in stmt._parameter_ordering
        ]
        ordered_keys = set(parameter_ordering)
        # Only plain string keys naming a column of the target table can be
        # looked up in stmt.table.c.  In a multi-table UPDATE, columns of
        # the extra tables come back from _column_as_key() as
        # (tablename, key) tuples, and non-Column expressions come back as
        # None.  Those are handled by _get_multitable_params() /
        # _get_stmt_parameters_params() and must not be looked up here.
        # The isinstance() check comes first because the column collection
        # only accepts string membership tests.
        cols = [
            stmt.table.c[key]
            for key in parameter_ordering
            if isinstance(key, util.string_types) and key in stmt.table.c
        ] + [c for c in stmt.table.c if c.key not in ordered_keys]
    else:
        cols = stmt.table.columns

    for c in cols:
        col_key = _getattr_col_key(c)

        if col_key in parameters and col_key not in check_columns:
            # an explicit value was supplied for this column
            _append_param_parameter(
                compiler,
                stmt,
                c,
                col_key,
                parameters,
                _col_bind_name,
                implicit_returning,
                implicit_return_defaults,
                values,
                kw,
            )

        elif compiler.isinsert:
            # primary key values are needed after the INSERT; the table's
            # autoincrement column is skipped here only when the dialect
            # will report it via lastrowid and RETURNING is not in use
            if (
                c.primary_key
                and need_pks
                and (
                    implicit_returning
                    or not postfetch_lastrowid
                    or c is not stmt.table._autoincrement_column
                )
            ):

                if implicit_returning:
                    _append_param_insert_pk_returning(
                        compiler, stmt, c, values, kw
                    )
                else:
                    _append_param_insert_pk(compiler, stmt, c, values, kw)

            elif c.default is not None:

                _append_param_insert_hasdefault(
                    compiler, stmt, c, implicit_return_defaults, values, kw
                )

            elif c.server_default is not None:
                if implicit_return_defaults and c in implicit_return_defaults:
                    compiler.returning.append(c)
                elif not c.primary_key:
                    compiler.postfetch.append(c)
            elif implicit_return_defaults and c in implicit_return_defaults:
                compiler.returning.append(c)
            elif (
                c.primary_key
                and c is not stmt.table._autoincrement_column
                and not c.nullable
            ):
                _warn_pk_with_no_anticipated_value(c)

        elif compiler.isupdate:
            _append_param_update(
                compiler, stmt, c, implicit_return_defaults, values, kw
            )
376
377
def _append_param_parameter(
    compiler,
    stmt,
    c,
    col_key,
    parameters,
    _col_bind_name,
    implicit_returning,
    implicit_return_defaults,
    values,
    kw,
):
    """Consume the explicit value given for column ``c`` and append the
    rendered ``(c, value)`` pair to ``values``.

    A literal value becomes a crud bind parameter.  Its name carries an
    ``_m0`` suffix for multi-VALUES statements, and it is marked required
    if the value is REQUIRED.  A SQL expression is compiled inline; the
    column is then added to ``compiler.returning`` (primary key with
    implicit RETURNING, or a requested return-default) or, for non-primary
    key columns, to ``compiler.postfetch``.
    """
    value = parameters.pop(col_key)

    if elements._is_literal(value):
        bind_name = _col_bind_name(c)
        if stmt._has_multi_parameters:
            bind_name = "%s_m0" % bind_name
        rendered = _create_bind_param(
            compiler,
            c,
            value,
            required=value is REQUIRED,
            name=bind_name,
            **kw
        )
    else:
        # untyped bind parameters take on the target column's type
        if isinstance(value, elements.BindParameter) and value.type._isnull:
            value = value._clone()
            value.type = c.type

        if (c.primary_key and implicit_returning) or (
            implicit_return_defaults and c in implicit_return_defaults
        ):
            compiler.returning.append(c)
        elif not c.primary_key:
            # postfetch means "SELECT the new row by primary key to get
            # the server-generated values", so it can never be used to
            # obtain the primary key itself
            compiler.postfetch.append(c)
        rendered = compiler.process(value.self_group(), **kw)

    values.append((c, rendered))
422
423
424def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
425 """Create a primary key expression in the INSERT statement and
426 possibly a RETURNING clause for it.
427
428 If the column has a Python-side default, we will create a bound
429 parameter for it and "pre-execute" the Python function. If
430 the column has a SQL expression default, or is a sequence,
431 we will add it directly into the INSERT statement and add a
432 RETURNING element to get the new value. If the column has a
433 server side default or is marked as the "autoincrement" column,
434 we will add a RETRUNING element to get at the value.
435
436 If all the above tests fail, that indicates a primary key column with no
437 noted default generation capabilities that has no parameter passed;
438 raise an exception.
439
440 """
441 if c.default is not None:
442 if c.default.is_sequence:
443 if compiler.dialect.supports_sequences and (
444 not c.default.optional
445 or not compiler.dialect.sequences_optional
446 ):
447 proc = compiler.process(c.default, **kw)
448 values.append((c, proc))
449 compiler.returning.append(c)
450 elif c.default.is_clause_element:
451 values.append(
452 (c, compiler.process(c.default.arg.self_group(), **kw))
453 )
454 compiler.returning.append(c)
455 else:
456 values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
457 elif c is stmt.table._autoincrement_column or c.server_default is not None:
458 compiler.returning.append(c)
459 elif not c.nullable:
460 # no .default, no .server_default, not autoincrement, we have
461 # no indication this primary key column will have any value
462 _warn_pk_with_no_anticipated_value(c)
463
464
def _create_insert_prefetch_bind_param(compiler, c, process=True, name=None):
    """Create a bind parameter (initial value None) for ``c`` and register
    ``c`` in ``compiler.insert_prefetch``.

    Returns the parameter, compiled unless ``process`` is False.
    """
    bind = _create_bind_param(compiler, c, None, process=process, name=name)
    compiler.insert_prefetch.append(c)
    return bind


def _create_update_prefetch_bind_param(compiler, c, process=True, name=None):
    """Create a bind parameter (initial value None) for ``c`` and register
    ``c`` in ``compiler.update_prefetch``.

    Returns the parameter, compiled unless ``process`` is False.
    """
    bind = _create_bind_param(compiler, c, None, process=process, name=name)
    compiler.update_prefetch.append(c)
    return bind
475
476
class _multiparam_column(elements.ColumnElement):
    """Stand-in column used to prefetch a Python-side default for one of
    the extra rows of a multi-VALUES statement.

    ``key`` is ``"<original key>_m<index + 1>"``, the same pattern used for
    the bind names of the extra rows.  ``default`` and ``type`` are copied
    from the original column.
    """

    _is_multiparam_column = True

    def __init__(self, original, index):
        self.index = index
        self.key = "%s_m%d" % (original.key, index + 1)
        self.original = original
        self.default = original.default
        self.type = original.type

    def __eq__(self, other):
        return (
            isinstance(other, _multiparam_column)
            and other.key == self.key
            and other.original == self.original
        )

    def __hash__(self):
        # Defining __eq__ without __hash__ sets __hash__ to None on
        # Python 3 (instances became unhashable there, unlike Python 2).
        # Hash on the same fields __eq__ compares so equal objects hash
        # equally.
        return hash((self.key, self.original))
493
494
495def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
496
497 if not c.default:
498 raise exc.CompileError(
499 "INSERT value for column %s is explicitly rendered as a bound"
500 "parameter in the VALUES clause; "
501 "a Python-side value or SQL expression is required" % c
502 )
503 elif c.default.is_clause_element:
504 return compiler.process(c.default.arg.self_group(), **kw)
505 else:
506 col = _multiparam_column(c, index)
507 if isinstance(stmt, dml.Insert):
508 return _create_insert_prefetch_bind_param(compiler, col)
509 else:
510 return _create_update_prefetch_bind_param(compiler, col)
511
512
513def _append_param_insert_pk(compiler, stmt, c, values, kw):
514 """Create a bound parameter in the INSERT statement to receive a
515 'prefetched' default value.
516
517 The 'prefetched' value indicates that we are to invoke a Python-side
518 default function or expliclt SQL expression before the INSERT statement
519 proceeds, so that we have a primary key value available.
520
521 if the column has no noted default generation capabilities, it has
522 no value passed in either; raise an exception.
523
524 """
525 if (
526 # column has a Python-side default
527 c.default is not None
528 and (
529 # and it won't be a Sequence
530 not c.default.is_sequence
531 or compiler.dialect.supports_sequences
532 )
533 ) or (
534 # column is the "autoincrement column"
535 c is stmt.table._autoincrement_column
536 and (
537 # and it's either a "sequence" or a
538 # pre-executable "autoincrement" sequence
539 compiler.dialect.supports_sequences
540 or compiler.dialect.preexecute_autoincrement_sequences
541 )
542 ):
543 values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
544 elif c.default is None and c.server_default is None and not c.nullable:
545 # no .default, no .server_default, not autoincrement, we have
546 # no indication this primary key column will have any value
547 _warn_pk_with_no_anticipated_value(c)
548
549
550def _append_param_insert_hasdefault(
551 compiler, stmt, c, implicit_return_defaults, values, kw
552):
553
554 if c.default.is_sequence:
555 if compiler.dialect.supports_sequences and (
556 not c.default.optional or not compiler.dialect.sequences_optional
557 ):
558 proc = compiler.process(c.default, **kw)
559 values.append((c, proc))
560 if implicit_return_defaults and c in implicit_return_defaults:
561 compiler.returning.append(c)
562 elif not c.primary_key:
563 compiler.postfetch.append(c)
564 elif c.default.is_clause_element:
565 proc = compiler.process(c.default.arg.self_group(), **kw)
566 values.append((c, proc))
567
568 if implicit_return_defaults and c in implicit_return_defaults:
569 compiler.returning.append(c)
570 elif not c.primary_key:
571 # don't add primary key column to postfetch
572 compiler.postfetch.append(c)
573 else:
574 values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
575
576
577def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
578
579 if c.default.is_sequence:
580 if compiler.dialect.supports_sequences and (
581 not c.default.optional or not compiler.dialect.sequences_optional
582 ):
583 proc = c.default
584 values.append((c, proc.next_value()))
585 elif c.default.is_clause_element:
586 proc = c.default.arg.self_group()
587 values.append((c, proc))
588 else:
589 values.append(
590 (c, _create_insert_prefetch_bind_param(compiler, c, process=False))
591 )
592
593
594def _append_param_update(
595 compiler, stmt, c, implicit_return_defaults, values, kw
596):
597
598 if c.onupdate is not None and not c.onupdate.is_sequence:
599 if c.onupdate.is_clause_element:
600 values.append(
601 (c, compiler.process(c.onupdate.arg.self_group(), **kw))
602 )
603 if implicit_return_defaults and c in implicit_return_defaults:
604 compiler.returning.append(c)
605 else:
606 compiler.postfetch.append(c)
607 else:
608 values.append((c, _create_update_prefetch_bind_param(compiler, c)))
609 elif c.server_onupdate is not None:
610 if implicit_return_defaults and c in implicit_return_defaults:
611 compiler.returning.append(c)
612 else:
613 compiler.postfetch.append(c)
614 elif (
615 implicit_return_defaults
616 and stmt._return_defaults is not True
617 and c in implicit_return_defaults
618 ):
619 compiler.returning.append(c)
620
621
def _get_multitable_params(
    compiler,
    stmt,
    stmt_parameters,
    check_columns,
    _col_bind_name,
    _getattr_col_key,
    values,
    kw,
):
    """Build SET entries for columns of the extra tables of a multi-table
    UPDATE.

    Every column of ``stmt._extra_froms`` that has a value in
    ``stmt_parameters`` gets a bind parameter (literal value) or an inline
    expression (postfetched).  It is recorded in ``check_columns`` under
    its ``_getattr_col_key`` so the main column scan skips it.  Then, for
    each extra table that received at least one value, the remaining
    columns' ``onupdate`` / ``server_onupdate`` are processed.  Pairs are
    appended to ``values``.
    """
    normalized_params = dict(
        (elements._clause_element_as_expr(c), param)
        for c, param in stmt_parameters.items()
    )

    # Tables receiving at least one explicit value, kept in
    # stmt._extra_froms order.  The tables were previously collected in a
    # set, whose iteration order depends on object ids, which made the
    # order of the onupdate SET entries (and hence the SQL) vary between
    # runs.
    affected_tables = []
    seen_tables = set()
    for t in stmt._extra_froms:
        for c in t.c:
            if c not in normalized_params:
                continue
            if t not in seen_tables:
                seen_tables.add(t)
                affected_tables.append(t)
            check_columns[_getattr_col_key(c)] = c
            value = normalized_params[c]
            if elements._is_literal(value):
                value = _create_bind_param(
                    compiler,
                    c,
                    value,
                    required=value is REQUIRED,
                    name=_col_bind_name(c),
                )
            else:
                compiler.postfetch.append(c)
                value = compiler.process(value.self_group(), **kw)
            values.append((c, value))

    # determine tables which are actually to be updated - process onupdate
    # and server_onupdate for these
    for t in affected_tables:
        for c in t.c:
            if c in normalized_params:
                continue
            elif c.onupdate is not None and not c.onupdate.is_sequence:
                if c.onupdate.is_clause_element:
                    values.append(
                        (
                            c,
                            compiler.process(
                                c.onupdate.arg.self_group(), **kw
                            ),
                        )
                    )
                    compiler.postfetch.append(c)
                else:
                    values.append(
                        (
                            c,
                            _create_update_prefetch_bind_param(
                                compiler, c, name=_col_bind_name(c)
                            ),
                        )
                    )
            elif c.server_onupdate is not None:
                compiler.postfetch.append(c)
684
685
686def _extend_values_for_multiparams(compiler, stmt, values, kw):
687 values_0 = values
688 values = [values]
689
690 for i, row in enumerate(stmt.parameters[1:]):
691 extension = []
692 for (col, param) in values_0:
693 if col in row or col.key in row:
694 key = col if col in row else col.key
695
696 if elements._is_literal(row[key]):
697 new_param = _create_bind_param(
698 compiler,
699 col,
700 row[key],
701 name="%s_m%d" % (col.key, i + 1),
702 **kw
703 )
704 else:
705 new_param = compiler.process(row[key].self_group(), **kw)
706 else:
707 new_param = _process_multiparam_default_bind(
708 compiler, stmt, col, i, kw
709 )
710
711 extension.append((col, new_param))
712
713 values.append(extension)
714
715 return values
716
717
718def _get_stmt_parameters_params(
719 compiler, parameters, stmt_parameters, _column_as_key, values, kw
720):
721 for k, v in stmt_parameters.items():
722 colkey = _column_as_key(k)
723 if colkey is not None:
724 parameters.setdefault(colkey, v)
725 else:
726 # a non-Column expression on the left side;
727 # add it to values() in an "as-is" state,
728 # coercing right side to bound param
729 if elements._is_literal(v):
730 v = compiler.process(
731 elements.BindParameter(None, v, type_=k.type), **kw
732 )
733 else:
734 if v._is_bind_parameter and v.type._isnull:
735 # either unique parameter, or other bound parameters that
736 # were passed in directly
737 # set type to that of the column unconditionally
738 v = v._with_binary_element_type(k.type)
739
740 v = compiler.process(v.self_group(), **kw)
741
742 values.append((k, v))
743
744
745def _get_returning_modifiers(compiler, stmt):
746 need_pks = (
747 compiler.isinsert
748 and not compiler.inline
749 and not stmt._returning
750 and not stmt._has_multi_parameters
751 )
752
753 implicit_returning = (
754 need_pks
755 and compiler.dialect.implicit_returning
756 and stmt.table.implicit_returning
757 )
758
759 if compiler.isinsert:
760 implicit_return_defaults = implicit_returning and stmt._return_defaults
761 elif compiler.isupdate:
762 implicit_return_defaults = (
763 compiler.dialect.implicit_returning
764 and stmt.table.implicit_returning
765 and stmt._return_defaults
766 )
767 else:
768 # this line is unused, currently we are always
769 # isinsert or isupdate
770 implicit_return_defaults = False # pragma: no cover
771
772 if implicit_return_defaults:
773 if stmt._return_defaults is True:
774 implicit_return_defaults = set(stmt.table.c)
775 else:
776 implicit_return_defaults = set(stmt._return_defaults)
777
778 postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid
779
780 return (
781 need_pks,
782 implicit_returning,
783 implicit_return_defaults,
784 postfetch_lastrowid,
785 )
786
787
def _warn_pk_with_no_anticipated_value(c):
    """Emit a warning that primary key column ``c`` will receive no value.

    The message states that the column has no default generator, no
    ``autoincrement=True`` / ``nullable=True`` flag and no explicit value.
    When the table's primary key has more than one column, a note about
    composite keys needing explicit ``autoincrement=True`` is appended.
    """
    table = c.table
    fragments = [
        "Column '%s.%s' is marked as a member of the "
        "primary key for table '%s', "
        "but has no Python-side or server-side default generator indicated, "
        "nor does it indicate 'autoincrement=True' or 'nullable=True', "
        "and no explicit value is passed. "
        "Primary key columns typically may not store NULL."
        % (table.fullname, c.name, table.fullname)
    ]
    if len(table.primary_key) > 1:
        fragments.append(
            " Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be "
            "indicated explicitly for composite (e.g. multicolumn) primary "
            "keys if AUTO_INCREMENT/SERIAL/IDENTITY "
            "behavior is expected for one of the columns in the primary key. "
            "CREATE TABLE statements are impacted by this change as well on "
            "most backends."
        )
    util.warn("".join(fragments))