1# sql/base.py
2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
3# <see AUTHORS file>
4#
5# This module is part of SQLAlchemy and is released under
6# the MIT License: http://www.opensource.org/licenses/mit-license.php
7
8"""Foundational utilities common to many sql modules.
9
10"""
11
12
13import itertools
14import re
15
16from .visitors import ClauseVisitor
17from .. import exc
18from .. import util
19
20
# Module-level symbol named "PARSE_AUTOCOMMIT"; its consumers are not
# visible in this module.
PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT")
# Module-level symbol named "NO_ARG"; presumably a sentinel meaning
# "no argument was passed" -- confirm against callers.
NO_ARG = util.symbol("NO_ARG")
23
24
class Immutable(object):
    """Mixin flagging a ClauseElement as 'immutable' for cloning purposes.

    Parameter-replacing copies are refused with ``NotImplementedError``
    and ``_clone()`` hands back the very same object rather than a copy.

    """

    def _clone(self):
        """Return this same instance; no copy is produced."""
        return self

    def params(self, *optionaldict, **kwargs):
        """Always raise ``NotImplementedError``; copying is unsupported."""
        raise NotImplementedError("Immutable objects do not support copying")

    def unique_params(self, *optionaldict, **kwargs):
        """Always raise ``NotImplementedError``; copying is unsupported."""
        raise NotImplementedError("Immutable objects do not support copying")
36
37
def _from_objects(*elements):
    """Chain together the ``_from_objects`` collection of each element.

    The ``_from_objects`` attribute of every element is read up front;
    the returned iterator then yields their members in argument order.

    """
    per_element = [item._from_objects for item in elements]
    return itertools.chain.from_iterable(per_element)
40
41
@util.decorator
def _generative(fn, *args, **kw):
    """Mark a method as generative.

    The wrapped method is invoked against the object returned by the
    receiver's ``_generate()`` rather than the receiver itself, and that
    generated object is what gets returned; the wrapped method's own
    return value is discarded.

    """

    receiver, remaining = args[0], args[1:]
    generated = receiver._generate()
    fn(generated, *remaining, **kw)
    return generated
49
50
class _DialectArgView(util.collections_abc.MutableMapping):
    """A dictionary view of dialect-level arguments in the form
    <dialectname>_<argument_name>.

    Every operation is forwarded to the ``dialect_options`` collection of
    the wrapped object; no state other than that object is held here.

    """

    def __init__(self, obj):
        # the object whose ``dialect_options`` this view reads and writes
        self.obj = obj

    def _key(self, key):
        """Split ``key`` at its first underscore into a
        ``(dialect, argument)`` pair, raising ``KeyError`` when the split
        does not produce two parts.

        """
        try:
            dialect_name, arg_name = key.split("_", 1)
        except ValueError as split_err:
            util.raise_(KeyError(key), replace_context=split_err)
        else:
            return dialect_name, arg_name

    def __getitem__(self, key):
        dialect_name, arg_name = self._key(key)

        try:
            per_dialect = self.obj.dialect_options[dialect_name]
        except exc.NoSuchModuleError as load_err:
            # an unlocatable dialect is reported as a missing key
            util.raise_(KeyError(key), replace_context=load_err)
        else:
            return per_dialect[arg_name]

    def __setitem__(self, key, value):
        try:
            dialect_name, arg_name = self._key(key)
        except KeyError as key_err:
            util.raise_(
                exc.ArgumentError(
                    "Keys must be of the form <dialectname>_<argname>"
                ),
                replace_context=key_err,
            )
        else:
            self.obj.dialect_options[dialect_name][arg_name] = value

    def __delitem__(self, key):
        dialect_name, arg_name = self._key(key)
        per_dialect = self.obj.dialect_options[dialect_name]
        del per_dialect[arg_name]

    def __len__(self):
        # only explicitly-set (non-default) arguments are counted
        total = 0
        for per_dialect in self.obj.dialect_options.values():
            total += len(per_dialect._non_defaults)
        return total

    def __iter__(self):
        # yields "<dialect>_<argument>" for each non-default argument
        return (
            util.safe_kwarg("%s_%s" % (name, argument))
            for name in self.obj.dialect_options
            for argument in self.obj.dialect_options[name]._non_defaults
        )
109
110
class _DialectArgDict(util.collections_abc.MutableMapping):
    """A dictionary view of dialect-level arguments for a specific
    dialect.

    Two plain dictionaries are kept: ``_non_defaults`` for values that
    were set through this mapping and ``_defaults`` for default values.
    Reads prefer ``_non_defaults``; writes and deletes only ever touch
    ``_non_defaults``.

    """

    def __init__(self):
        self._non_defaults = {}
        self._defaults = {}

    def __len__(self):
        # number of distinct keys across both dictionaries
        all_keys = set(self._non_defaults).union(self._defaults)
        return len(all_keys)

    def __iter__(self):
        all_keys = set(self._non_defaults).union(self._defaults)
        return iter(all_keys)

    def __getitem__(self, key):
        explicit = self._non_defaults
        source = explicit if key in explicit else self._defaults
        return source[key]

    def __setitem__(self, key, value):
        self._non_defaults[key] = value

    def __delitem__(self, key):
        del self._non_defaults[key]
141
142
class DialectKWArgs(object):
    """Establish the ability for a class to have dialect-specific arguments
    with defaults and constructor validation.

    The :class:`.DialectKWArgs` interacts with the
    :attr:`.DefaultDialect.construct_arguments` present on a dialect.

    .. seealso::

        :attr:`.DefaultDialect.construct_arguments`

    """

    @classmethod
    def argument_for(cls, dialect_name, argument_name, default):
        """Add a new kind of dialect-specific keyword argument for this class.

        E.g.::

            Index.argument_for("mydialect", "length", None)

            some_index = Index('a', 'b', mydialect_length=5)

        The :meth:`.DialectKWArgs.argument_for` method is a per-argument
        way of adding extra arguments to the
        :attr:`.DefaultDialect.construct_arguments` dictionary. This
        dictionary provides a list of argument names accepted by various
        schema-level constructs on behalf of a dialect.

        New dialects should typically specify this dictionary all at once as a
        data member of the dialect class.  The use case for ad-hoc addition of
        argument names is typically for end-user code that is also using
        a custom compilation scheme which consumes the additional arguments.

        :param dialect_name: name of a dialect.  The dialect must be
         locatable, else a :class:`.NoSuchModuleError` is raised.   The
         dialect must also include an existing
         :attr:`.DefaultDialect.construct_arguments` collection, indicating
         that it participates in the keyword-argument validation and default
         system, else :class:`.ArgumentError` is raised.  If the dialect does
         not include this collection, then any keyword argument can be
         specified on behalf of this dialect already.  All dialects packaged
         within SQLAlchemy include this collection, however for third party
         dialects, support may vary.

        :param argument_name: name of the parameter.

        :param default: default value of the parameter.

        .. versionadded:: 0.9.4

        """

        construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
        if construct_arg_dictionary is None:
            # a None registry entry means the dialect declares no
            # construct_arguments collection at all
            raise exc.ArgumentError(
                "Dialect '%s' does not have keyword-argument "
                "validation and defaults enabled" % dialect_name
            )
        if cls not in construct_arg_dictionary:
            construct_arg_dictionary[cls] = {}
        construct_arg_dictionary[cls][argument_name] = default

    @util.memoized_property
    def dialect_kwargs(self):
        """A collection of keyword arguments specified as dialect-specific
        options to this construct.

        The arguments are present here in their original ``<dialect>_<kwarg>``
        format.  Only arguments that were actually passed are included;
        unlike the :attr:`.DialectKWArgs.dialect_options` collection, which
        contains all options known by this dialect including defaults.

        The collection is also writable; keys are accepted of the
        form ``<dialect>_<kwarg>`` where the value will be assembled
        into the list of options.

        .. versionadded:: 0.9.2

        .. versionchanged:: 0.9.4 The :attr:`.DialectKWArgs.dialect_kwargs`
           collection is now writable.

        .. seealso::

            :attr:`.DialectKWArgs.dialect_options` - nested dictionary form

        """
        return _DialectArgView(self)

    @property
    def kwargs(self):
        """A synonym for :attr:`.DialectKWArgs.dialect_kwargs`."""
        return self.dialect_kwargs

    @util.dependencies("sqlalchemy.dialects")
    def _kw_reg_for_dialect(dialects, dialect_name):
        """Return a copy of the named dialect's ``construct_arguments``
        as a dictionary, or ``None`` when the dialect class has that
        attribute set to ``None``.

        The ``dialects`` argument is presumably supplied by the
        ``util.dependencies`` decorator -- confirm against ``util``.

        """
        dialect_cls = dialects.registry.load(dialect_name)
        if dialect_cls.construct_arguments is None:
            return None
        return dict(dialect_cls.construct_arguments)

    # registry shared by all DialectKWArgs classes, keyed by dialect name;
    # presumably fills missing keys by calling _kw_reg_for_dialect --
    # confirm against util.PopulateDict
    _kw_registry = util.PopulateDict(_kw_reg_for_dialect)

    def _kw_reg_for_dialect_cls(self, dialect_name):
        """Build the :class:`._DialectArgDict` of defaults that applies to
        this object's class for the given dialect.

        When the dialect has no registry entry (``None``), the single
        default key ``"*"`` is installed.  Otherwise the defaults
        registered for each class in this object's MRO are merged, walking
        from the base-most class to the most derived so that the more
        derived entries take precedence.

        """
        construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name]
        d = _DialectArgDict()

        if construct_arg_dictionary is None:
            d._defaults.update({"*": None})
        else:
            for cls in reversed(self.__class__.__mro__):
                if cls in construct_arg_dictionary:
                    d._defaults.update(construct_arg_dictionary[cls])
        return d

    @util.memoized_property
    def dialect_options(self):
        """A collection of keyword arguments specified as dialect-specific
        options to this construct.

        This is a two-level nested registry, keyed to ``<dialect_name>``
        and ``<argument_name>``.  For example, the ``postgresql_where``
        argument would be locatable as::

            arg = my_object.dialect_options['postgresql']['where']

        .. versionadded:: 0.9.2

        .. seealso::

            :attr:`.DialectKWArgs.dialect_kwargs` - flat dictionary form

        """

        return util.PopulateDict(
            util.portable_instancemethod(self._kw_reg_for_dialect_cls)
        )

    def _validate_dialect_kwargs(self, kwargs):
        """Validate and store ``<dialectname>_<argument>`` keyword
        arguments into :attr:`.DialectKWArgs.dialect_options`.

        :param kwargs: mapping of keyword names to values; an empty
         mapping is a no-op.

        :raises TypeError: a key is not of the form
         ``<dialectname>_<argument>``.

        :raises ArgumentError: the dialect's options contain neither the
         ``"*"`` key nor the given argument name.

        A dialect name that raises :class:`.NoSuchModuleError` on lookup
        produces a warning; the value is then stored unvalidated.

        """
        # validate remaining kwargs that they all specify DB prefixes

        if not kwargs:
            return

        for k in kwargs:
            # non-greedy first group: split at the first underscore
            m = re.match(r"^(.+?)_(.+)$", k)
            if not m:
                raise TypeError(
                    "Additional arguments should be "
                    "named <dialectname>_<argument>, got '%s'" % k
                )
            dialect_name, arg_name = m.group(1, 2)

            try:
                construct_arg_dictionary = self.dialect_options[dialect_name]
            except exc.NoSuchModuleError:
                util.warn(
                    "Can't validate argument %r; can't "
                    "locate any SQLAlchemy dialect named %r"
                    % (k, dialect_name)
                )
                # install an accept-anything ("*") entry for the unknown
                # dialect and record the value as explicitly given
                self.dialect_options[dialect_name] = d = _DialectArgDict()
                d._defaults.update({"*": None})
                d._non_defaults[arg_name] = kwargs[k]
            else:
                if (
                    "*" not in construct_arg_dictionary
                    and arg_name not in construct_arg_dictionary
                ):
                    raise exc.ArgumentError(
                        "Argument %r is not accepted by "
                        "dialect %r on behalf of %r"
                        % (k, dialect_name, self.__class__)
                    )
                else:
                    construct_arg_dictionary[arg_name] = kwargs[k]
319
320
class Generative(object):
    """Allow a ClauseElement to generate itself via the
    @_generative decorator.

    """

    def _generate(self):
        """Return a new instance of this object's class, created without
        calling ``__init__``, whose ``__dict__`` is a shallow copy of this
        object's ``__dict__``.

        """
        cls = self.__class__
        duplicate = cls.__new__(cls)
        duplicate.__dict__ = self.__dict__.copy()
        return duplicate
331
332
class Executable(Generative):
    """Mark a :class:`_expression.ClauseElement` as supporting execution.

    :class:`.Executable` is a superclass for all "statement" types
    of objects, including :func:`select`, :func:`delete`, :func:`update`,
    :func:`insert`, :func:`text`.

    """

    supports_execution = True
    _execution_options = util.immutabledict()
    _bind = None

    @_generative
    def execution_options(self, **kw):
        """Set non-SQL options for the statement which take effect during
        execution.

        Execution options can be set on a per-statement or
        per :class:`_engine.Connection` basis.   Additionally, the
        :class:`_engine.Engine` and ORM :class:`~.orm.query.Query`
        objects provide
        access to execution options which they in turn configure upon
        connections.

        The :meth:`execution_options` method is generative.  A new
        instance of this statement is returned that contains the options::

            statement = select([table.c.x, table.c.y])
            statement = statement.execution_options(autocommit=True)

        Note that only a subset of possible execution options can be applied
        to a statement - these include "autocommit" and "stream_results",
        but not "isolation_level" or "compiled_cache".
        See :meth:`_engine.Connection.execution_options` for a full list of
        possible options.

        .. seealso::

            :meth:`_engine.Connection.execution_options`

            :meth:`_query.Query.execution_options`

            :meth:`.Executable.get_execution_options`

        """
        # options rejected at the statement level, checked in this order,
        # paired with the error message raised for each
        connection_only = (
            (
                "isolation_level",
                "'isolation_level' execution option may only be specified "
                "on Connection.execution_options(), or "
                "per-engine using the isolation_level "
                "argument to create_engine().",
            ),
            (
                "compiled_cache",
                "'compiled_cache' execution option may only be specified "
                "on Connection.execution_options(), not per statement.",
            ),
        )
        for option_name, message in connection_only:
            if option_name in kw:
                raise exc.ArgumentError(message)
        self._execution_options = self._execution_options.union(kw)

    def get_execution_options(self):
        """Get the non-SQL options which will take effect during execution.

        .. versionadded:: 1.3

        .. seealso::

            :meth:`.Executable.execution_options`

        """
        return self._execution_options

    def execute(self, *multiparams, **params):
        """Compile and execute this :class:`.Executable`.

        Raises :class:`.UnboundExecutionError` when :attr:`.bind` is
        ``None``.

        """
        engine = self.bind
        if engine is not None:
            return engine._execute_clauseelement(self, multiparams, params)

        label = getattr(self, "description", self.__class__.__name__)
        raise exc.UnboundExecutionError(
            "This %s is not directly bound to a Connection or Engine. "
            "Use the .execute() method of a Connection or Engine "
            "to execute this construct." % label
        )

    def scalar(self, *multiparams, **params):
        """Compile and execute this :class:`.Executable`, returning the
        result's scalar representation.

        """
        result = self.execute(*multiparams, **params)
        return result.scalar()

    @property
    def bind(self):
        """Returns the :class:`_engine.Engine` or :class:`_engine.Connection`
        to which this :class:`.Executable` is bound, or None if none found.

        This is a traversal which checks locally, then
        checks among the "from" clauses of associated objects
        until a bound engine or connection is found.

        """
        local_bind = self._bind
        if local_bind is not None:
            return local_bind

        for from_obj in _from_objects(self):
            # skip ourselves so the property does not recurse on itself
            if from_obj is not self:
                found = from_obj.bind
                if found is not None:
                    return found
        return None
446
447
class SchemaEventTarget(object):
    """Base class for elements that are the targets of :class:`.DDLEvents`
    events.

    This includes :class:`.SchemaItem` as well as :class:`.SchemaType`.

    """

    def _set_parent(self, parent, **kw):
        """Associate with this SchemaEvent's parent object.

        The base implementation does nothing.

        """

    def _set_parent_with_dispatch(self, parent, **kw):
        """Call :meth:`._set_parent`, bracketed by the
        ``before_parent_attach`` and ``after_parent_attach`` hooks of
        ``self.dispatch``.

        """
        # hook ordering is significant: before -> attach -> after
        self.dispatch.before_parent_attach(self, parent)
        self._set_parent(parent, **kw)
        self.dispatch.after_parent_attach(self, parent)
463
464
class SchemaVisitor(ClauseVisitor):
    """Define the visiting for ``SchemaItem`` objects."""

    # single traversal option, "schema_visitor", switched on
    __traverse_options__ = dict(schema_visitor=True)
469
470
class ColumnCollection(util.OrderedProperties):
    """An ordered dictionary that stores a list of ColumnElement
    instances.

    Overrides the ``__eq__()`` method to produce SQL clauses between
    sets of correlated columns.

    Alongside the keyed ``_data`` mapping, a plain ``_all_columns`` list
    records the column objects themselves.

    """

    __slots__ = "_all_columns"

    def __init__(self, *columns):
        super(ColumnCollection, self).__init__()
        # __setattr__ is disabled on this class, so go through object
        object.__setattr__(self, "_all_columns", [])
        for column in columns:
            self.add(column)

    def __str__(self):
        return repr(list(map(str, self)))

    def replace(self, column):
        """Add the given column to this collection, removing unaliased
        versions of this column as well as existing columns with the
        same key.

        E.g.::

            t = Table('sometable', metadata, Column('col1', Integer))
            t.columns.replace(Column('col1', Integer, key='columnone'))

        will remove the original 'col1' from the collection, and add
        the new column under the key 'columnone'.

        Used by schema.Column to override columns during table reflection.

        """
        displaced = None
        if column.name in self and column.key != column.name:
            same_name = self[column.name]
            # only an entry stored under its own name is dropped
            if same_name.name == same_name.key:
                displaced = same_name
                del self._data[same_name.key]

        if column.key in self._data:
            displaced = self._data[column.key]

        self._data[column.key] = column
        if displaced is None:
            self._all_columns.append(column)
        else:
            # swap in place so the displaced column's position is kept
            self._all_columns[:] = [
                column if present is displaced else present
                for present in self._all_columns
            ]

    def add(self, column):
        """Add a column to this collection.

        The key attribute of the column will be used as the hash key
        for this dictionary.

        """
        if column.key:
            self[column.key] = column
        else:
            raise exc.ArgumentError(
                "Can't add unnamed column to column collection"
            )

    def __delitem__(self, key):
        raise NotImplementedError()

    def __setattr__(self, key, obj):
        raise NotImplementedError()

    def __setitem__(self, key, value):
        if key in self:

            # this warning is primarily to catch select() statements
            # which have conflicting column names in their exported
            # columns collection

            current = self[key]

            if current is value:
                return

            if not current.shares_lineage(value):
                util.warn(
                    "Column %r on table %r being replaced by "
                    "%r, which has the same key.  Consider "
                    "use_labels for select() statements."
                    % (key, getattr(current, "table", None), value)
                )

            # pop out memoized proxy_set as this
            # operation may very well be occurring
            # in a _make_proxy operation
            util.memoized_property.reset(value, "proxy_set")

        self._all_columns.append(value)
        self._data[key] = value

    def clear(self):
        raise NotImplementedError()

    def remove(self, column):
        """Drop ``column`` from both the keyed data and the column list."""
        del self._data[column.key]
        self._all_columns[:] = [
            kept for kept in self._all_columns if kept is not column
        ]

    def update(self, iter_):
        """Merge an iterable of ``(label, column)`` pairs into this
        collection.

        """
        pairs = list(iter_)
        already_present = set(self._all_columns)
        self._all_columns.extend(
            col for _label, col in pairs if col not in already_present
        )
        self._data.update((label, col) for label, col in pairs)

    def extend(self, iter_):
        """Merge an iterable of columns, each keyed by its ``key``."""
        incoming = list(iter_)
        already_present = set(self._all_columns)
        self._all_columns.extend(
            col for col in incoming if col not in already_present
        )
        self._data.update((col.key, col) for col in incoming)

    __hash__ = None

    @util.dependencies("sqlalchemy.sql.elements")
    def __eq__(self, elements, other):
        # one ``==`` comparison per (other column, local column) pair
        # that shares lineage, joined with and_()
        comparisons = [
            theirs == ours
            for theirs in getattr(other, "_all_columns", other)
            for ours in self._all_columns
            if theirs.shares_lineage(ours)
        ]
        return elements.and_(*comparisons)

    def __contains__(self, other):
        if isinstance(other, util.string_types):
            return util.OrderedProperties.__contains__(self, other)
        raise exc.ArgumentError("__contains__ requires a string argument")

    def __getstate__(self):
        return dict(_data=self._data, _all_columns=self._all_columns)

    def __setstate__(self, state):
        for attr_name in ("_data", "_all_columns"):
            object.__setattr__(self, attr_name, state[attr_name])

    def contains_column(self, col):
        """Return True if ``col`` is among the stored column objects."""
        members = set(self._all_columns)
        return col in members

    def as_immutable(self):
        """Return an :class:`.ImmutableColumnCollection` built from this
        collection's ``_data`` and ``_all_columns``.

        """
        return ImmutableColumnCollection(self._data, self._all_columns)
623
624
class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection):
    """A :class:`.ColumnCollection` combined with
    ``util.ImmutableProperties``, constructed from existing ``data`` and
    ``all_columns`` objects.

    """

    def __init__(self, data, all_columns):
        util.ImmutableProperties.__init__(self, data)
        # ColumnCollection disables __setattr__, so go through object
        object.__setattr__(self, "_all_columns", all_columns)

    # both mutators are rebound to ImmutableProperties._immutable
    remove = util.ImmutableProperties._immutable
    extend = remove
631
632
class ColumnSet(util.ordered_column_set):
    """An ``ordered_column_set`` with column-collection style helpers
    and an ``__eq__()`` that produces a SQL conjunction.

    """

    def contains_column(self, col):
        """Return True if ``col`` is a member of this set."""
        return col in self

    def extend(self, cols):
        """Add every column of ``cols`` to this set, in iteration order."""
        for column in cols:
            self.add(column)

    def __add__(self, other):
        # the result is a plain list, not a ColumnSet
        combined = list(self)
        combined.extend(other)
        return combined

    @util.dependencies("sqlalchemy.sql.elements")
    def __eq__(self, elements, other):
        # one ``==`` comparison per (other column, local column) pair
        # that shares lineage, joined with and_()
        comparisons = [
            theirs == ours
            for theirs in other
            for ours in self
            if theirs.shares_lineage(ours)
        ]
        return elements.and_(*comparisons)

    def __hash__(self):
        return hash(tuple(self))
655
656
def _bind_or_error(schemaitem, msg=None):
    """Return ``schemaitem.bind``, raising when it is falsy.

    :param schemaitem: object exposing a ``bind`` attribute; its
     ``fullname`` (or, failing that, ``name``) attribute is used to label
     the default error message when present.

    :param msg: optional replacement for the default error message.

    :raises UnboundExecutionError: ``schemaitem.bind`` is falsy.

    """
    bind = schemaitem.bind
    if bind:
        return bind

    cls_name = schemaitem.__class__.__name__
    fallback_label = getattr(schemaitem, "name", None)
    label = getattr(schemaitem, "fullname", fallback_label)
    item = (
        "%s object %r" % (cls_name, label)
        if label
        else "%s object" % cls_name
    )
    if msg is None:
        msg = (
            "%s is not bound to an Engine or Connection.  "
            "Execution can not proceed without a database to execute "
            "against." % item
        )
    raise exc.UnboundExecutionError(msg)