1# sql/base.py 
    2# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors 
    3# <see AUTHORS file> 
    4# 
    5# This module is part of SQLAlchemy and is released under 
    6# the MIT License: http://www.opensource.org/licenses/mit-license.php 
    7 
    8"""Foundational utilities common to many sql modules. 
    9 
    10""" 
    11 
    12 
    13import itertools 
    14import re 
    15 
    16from .visitors import ClauseVisitor 
    17from .. import exc 
    18from .. import util 
    19 
    20 
    21PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT") 
    22NO_ARG = util.symbol("NO_ARG") 
    23 
    24 
    25class Immutable(object): 
    26    """mark a ClauseElement as 'immutable' when expressions are cloned.""" 
    27 
    28    def unique_params(self, *optionaldict, **kwargs): 
    29        raise NotImplementedError("Immutable objects do not support copying") 
    30 
    31    def params(self, *optionaldict, **kwargs): 
    32        raise NotImplementedError("Immutable objects do not support copying") 
    33 
    34    def _clone(self): 
    35        return self 
    36 
    37 
    38def _from_objects(*elements): 
    39    return itertools.chain(*[element._from_objects for element in elements]) 
    40 
    41 
    42@util.decorator 
    43def _generative(fn, *args, **kw): 
    44    """Mark a method as generative.""" 
    45 
    46    self = args[0]._generate() 
    47    fn(self, *args[1:], **kw) 
    48    return self 
    49 
    50 
    51class _DialectArgView(util.collections_abc.MutableMapping): 
    52    """A dictionary view of dialect-level arguments in the form 
    53    <dialectname>_<argument_name>. 
    54 
    55    """ 
    56 
    57    def __init__(self, obj): 
    58        self.obj = obj 
    59 
    60    def _key(self, key): 
    61        try: 
    62            dialect, value_key = key.split("_", 1) 
    63        except ValueError as err: 
    64            util.raise_(KeyError(key), replace_context=err) 
    65        else: 
    66            return dialect, value_key 
    67 
    68    def __getitem__(self, key): 
    69        dialect, value_key = self._key(key) 
    70 
    71        try: 
    72            opt = self.obj.dialect_options[dialect] 
    73        except exc.NoSuchModuleError as err: 
    74            util.raise_(KeyError(key), replace_context=err) 
    75        else: 
    76            return opt[value_key] 
    77 
    78    def __setitem__(self, key, value): 
    79        try: 
    80            dialect, value_key = self._key(key) 
    81        except KeyError as err: 
    82            util.raise_( 
    83                exc.ArgumentError( 
    84                    "Keys must be of the form <dialectname>_<argname>" 
    85                ), 
    86                replace_context=err, 
    87            ) 
    88        else: 
    89            self.obj.dialect_options[dialect][value_key] = value 
    90 
    91    def __delitem__(self, key): 
    92        dialect, value_key = self._key(key) 
    93        del self.obj.dialect_options[dialect][value_key] 
    94 
    95    def __len__(self): 
    96        return sum( 
    97            len(args._non_defaults) 
    98            for args in self.obj.dialect_options.values() 
    99        ) 
    100 
    101    def __iter__(self): 
    102        return ( 
    103            util.safe_kwarg("%s_%s" % (dialect_name, value_name)) 
    104            for dialect_name in self.obj.dialect_options 
    105            for value_name in self.obj.dialect_options[ 
    106                dialect_name 
    107            ]._non_defaults 
    108        ) 
    109 
    110 
    111class _DialectArgDict(util.collections_abc.MutableMapping): 
    112    """A dictionary view of dialect-level arguments for a specific 
    113    dialect. 
    114 
    115    Maintains a separate collection of user-specified arguments 
    116    and dialect-specified default arguments. 
    117 
    118    """ 
    119 
    120    def __init__(self): 
    121        self._non_defaults = {} 
    122        self._defaults = {} 
    123 
    124    def __len__(self): 
    125        return len(set(self._non_defaults).union(self._defaults)) 
    126 
    127    def __iter__(self): 
    128        return iter(set(self._non_defaults).union(self._defaults)) 
    129 
    130    def __getitem__(self, key): 
    131        if key in self._non_defaults: 
    132            return self._non_defaults[key] 
    133        else: 
    134            return self._defaults[key] 
    135 
    136    def __setitem__(self, key, value): 
    137        self._non_defaults[key] = value 
    138 
    139    def __delitem__(self, key): 
    140        del self._non_defaults[key] 
    141 
    142 
    143class DialectKWArgs(object): 
    144    """Establish the ability for a class to have dialect-specific arguments 
    145    with defaults and constructor validation. 
    146 
    147    The :class:`.DialectKWArgs` interacts with the 
    148    :attr:`.DefaultDialect.construct_arguments` present on a dialect. 
    149 
    150    .. seealso:: 
    151 
    152        :attr:`.DefaultDialect.construct_arguments` 
    153 
    154    """ 
    155 
    156    @classmethod 
    157    def argument_for(cls, dialect_name, argument_name, default): 
    158        """Add a new kind of dialect-specific keyword argument for this class. 
    159 
    160        E.g.:: 
    161 
    162            Index.argument_for("mydialect", "length", None) 
    163 
    164            some_index = Index('a', 'b', mydialect_length=5) 
    165 
    166        The :meth:`.DialectKWArgs.argument_for` method is a per-argument 
    167        way adding extra arguments to the 
    168        :attr:`.DefaultDialect.construct_arguments` dictionary. This 
    169        dictionary provides a list of argument names accepted by various 
    170        schema-level constructs on behalf of a dialect. 
    171 
    172        New dialects should typically specify this dictionary all at once as a 
    173        data member of the dialect class.  The use case for ad-hoc addition of 
    174        argument names is typically for end-user code that is also using 
    175        a custom compilation scheme which consumes the additional arguments. 
    176 
    177        :param dialect_name: name of a dialect.  The dialect must be 
    178         locatable, else a :class:`.NoSuchModuleError` is raised.   The 
    179         dialect must also include an existing 
    180         :attr:`.DefaultDialect.construct_arguments` collection, indicating 
    181         that it participates in the keyword-argument validation and default 
    182         system, else :class:`.ArgumentError` is raised.  If the dialect does 
    183         not include this collection, then any keyword argument can be 
    184         specified on behalf of this dialect already.  All dialects packaged 
    185         within SQLAlchemy include this collection, however for third party 
    186         dialects, support may vary. 
    187 
    188        :param argument_name: name of the parameter. 
    189 
    190        :param default: default value of the parameter. 
    191 
    192        .. versionadded:: 0.9.4 
    193 
    194        """ 
    195 
    196        construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name] 
    197        if construct_arg_dictionary is None: 
    198            raise exc.ArgumentError( 
    199                "Dialect '%s' does have keyword-argument " 
    200                "validation and defaults enabled configured" % dialect_name 
    201            ) 
    202        if cls not in construct_arg_dictionary: 
    203            construct_arg_dictionary[cls] = {} 
    204        construct_arg_dictionary[cls][argument_name] = default 
    205 
    206    @util.memoized_property 
    207    def dialect_kwargs(self): 
    208        """A collection of keyword arguments specified as dialect-specific 
    209        options to this construct. 
    210 
    211        The arguments are present here in their original ``<dialect>_<kwarg>`` 
    212        format.  Only arguments that were actually passed are included; 
    213        unlike the :attr:`.DialectKWArgs.dialect_options` collection, which 
    214        contains all options known by this dialect including defaults. 
    215 
    216        The collection is also writable; keys are accepted of the 
    217        form ``<dialect>_<kwarg>`` where the value will be assembled 
    218        into the list of options. 
    219 
    220        .. versionadded:: 0.9.2 
    221 
    222        .. versionchanged:: 0.9.4 The :attr:`.DialectKWArgs.dialect_kwargs` 
    223           collection is now writable. 
    224 
    225        .. seealso:: 
    226 
    227            :attr:`.DialectKWArgs.dialect_options` - nested dictionary form 
    228 
    229        """ 
    230        return _DialectArgView(self) 
    231 
    232    @property 
    233    def kwargs(self): 
    234        """A synonym for :attr:`.DialectKWArgs.dialect_kwargs`.""" 
    235        return self.dialect_kwargs 
    236 
    237    @util.dependencies("sqlalchemy.dialects") 
    238    def _kw_reg_for_dialect(dialects, dialect_name): 
    239        dialect_cls = dialects.registry.load(dialect_name) 
    240        if dialect_cls.construct_arguments is None: 
    241            return None 
    242        return dict(dialect_cls.construct_arguments) 
    243 
    244    _kw_registry = util.PopulateDict(_kw_reg_for_dialect) 
    245 
    246    def _kw_reg_for_dialect_cls(self, dialect_name): 
    247        construct_arg_dictionary = DialectKWArgs._kw_registry[dialect_name] 
    248        d = _DialectArgDict() 
    249 
    250        if construct_arg_dictionary is None: 
    251            d._defaults.update({"*": None}) 
    252        else: 
    253            for cls in reversed(self.__class__.__mro__): 
    254                if cls in construct_arg_dictionary: 
    255                    d._defaults.update(construct_arg_dictionary[cls]) 
    256        return d 
    257 
    258    @util.memoized_property 
    259    def dialect_options(self): 
    260        """A collection of keyword arguments specified as dialect-specific 
    261        options to this construct. 
    262 
    263        This is a two-level nested registry, keyed to ``<dialect_name>`` 
    264        and ``<argument_name>``.  For example, the ``postgresql_where`` 
    265        argument would be locatable as:: 
    266 
    267            arg = my_object.dialect_options['postgresql']['where'] 
    268 
    269        .. versionadded:: 0.9.2 
    270 
    271        .. seealso:: 
    272 
    273            :attr:`.DialectKWArgs.dialect_kwargs` - flat dictionary form 
    274 
    275        """ 
    276 
    277        return util.PopulateDict( 
    278            util.portable_instancemethod(self._kw_reg_for_dialect_cls) 
    279        ) 
    280 
    281    def _validate_dialect_kwargs(self, kwargs): 
    282        # validate remaining kwargs that they all specify DB prefixes 
    283 
    284        if not kwargs: 
    285            return 
    286 
    287        for k in kwargs: 
    288            m = re.match("^(.+?)_(.+)$", k) 
    289            if not m: 
    290                raise TypeError( 
    291                    "Additional arguments should be " 
    292                    "named <dialectname>_<argument>, got '%s'" % k 
    293                ) 
    294            dialect_name, arg_name = m.group(1, 2) 
    295 
    296            try: 
    297                construct_arg_dictionary = self.dialect_options[dialect_name] 
    298            except exc.NoSuchModuleError: 
    299                util.warn( 
    300                    "Can't validate argument %r; can't " 
    301                    "locate any SQLAlchemy dialect named %r" 
    302                    % (k, dialect_name) 
    303                ) 
    304                self.dialect_options[dialect_name] = d = _DialectArgDict() 
    305                d._defaults.update({"*": None}) 
    306                d._non_defaults[arg_name] = kwargs[k] 
    307            else: 
    308                if ( 
    309                    "*" not in construct_arg_dictionary 
    310                    and arg_name not in construct_arg_dictionary 
    311                ): 
    312                    raise exc.ArgumentError( 
    313                        "Argument %r is not accepted by " 
    314                        "dialect %r on behalf of %r" 
    315                        % (k, dialect_name, self.__class__) 
    316                    ) 
    317                else: 
    318                    construct_arg_dictionary[arg_name] = kwargs[k] 
    319 
    320 
    321class Generative(object): 
    322    """Allow a ClauseElement to generate itself via the 
    323    @_generative decorator. 
    324 
    325    """ 
    326 
    327    def _generate(self): 
    328        s = self.__class__.__new__(self.__class__) 
    329        s.__dict__ = self.__dict__.copy() 
    330        return s 
    331 
    332 
    333class Executable(Generative): 
    334    """Mark a :class:`_expression.ClauseElement` as supporting execution. 
    335 
    336    :class:`.Executable` is a superclass for all "statement" types 
    337    of objects, including :func:`select`, :func:`delete`, :func:`update`, 
    338    :func:`insert`, :func:`text`. 
    339 
    340    """ 
    341 
    342    supports_execution = True 
    343    _execution_options = util.immutabledict() 
    344    _bind = None 
    345 
    346    @_generative 
    347    def execution_options(self, **kw): 
    348        """Set non-SQL options for the statement which take effect during 
    349        execution. 
    350 
    351        Execution options can be set on a per-statement or 
    352        per :class:`_engine.Connection` basis.   Additionally, the 
    353        :class:`_engine.Engine` and ORM :class:`~.orm.query.Query` 
    354        objects provide 
    355        access to execution options which they in turn configure upon 
    356        connections. 
    357 
    358        The :meth:`execution_options` method is generative.  A new 
    359        instance of this statement is returned that contains the options:: 
    360 
    361            statement = select([table.c.x, table.c.y]) 
    362            statement = statement.execution_options(autocommit=True) 
    363 
    364        Note that only a subset of possible execution options can be applied 
    365        to a statement - these include "autocommit" and "stream_results", 
    366        but not "isolation_level" or "compiled_cache". 
    367        See :meth:`_engine.Connection.execution_options` for a full list of 
    368        possible options. 
    369 
    370        .. seealso:: 
    371 
    372            :meth:`_engine.Connection.execution_options` 
    373 
    374            :meth:`_query.Query.execution_options` 
    375 
    376            :meth:`.Executable.get_execution_options` 
    377 
    378        """ 
    379        if "isolation_level" in kw: 
    380            raise exc.ArgumentError( 
    381                "'isolation_level' execution option may only be specified " 
    382                "on Connection.execution_options(), or " 
    383                "per-engine using the isolation_level " 
    384                "argument to create_engine()." 
    385            ) 
    386        if "compiled_cache" in kw: 
    387            raise exc.ArgumentError( 
    388                "'compiled_cache' execution option may only be specified " 
    389                "on Connection.execution_options(), not per statement." 
    390            ) 
    391        self._execution_options = self._execution_options.union(kw) 
    392 
    393    def get_execution_options(self): 
    394        """Get the non-SQL options which will take effect during execution. 
    395 
    396        .. versionadded:: 1.3 
    397 
    398        .. seealso:: 
    399 
    400            :meth:`.Executable.execution_options` 
    401 
    402        """ 
    403        return self._execution_options 
    404 
    405    def execute(self, *multiparams, **params): 
    406        """Compile and execute this :class:`.Executable`.""" 
    407        e = self.bind 
    408        if e is None: 
    409            label = getattr(self, "description", self.__class__.__name__) 
    410            msg = ( 
    411                "This %s is not directly bound to a Connection or Engine. " 
    412                "Use the .execute() method of a Connection or Engine " 
    413                "to execute this construct." % label 
    414            ) 
    415            raise exc.UnboundExecutionError(msg) 
    416        return e._execute_clauseelement(self, multiparams, params) 
    417 
    418    def scalar(self, *multiparams, **params): 
    419        """Compile and execute this :class:`.Executable`, returning the 
    420        result's scalar representation. 
    421 
    422        """ 
    423        return self.execute(*multiparams, **params).scalar() 
    424 
    425    @property 
    426    def bind(self): 
    427        """Returns the :class:`_engine.Engine` or :class:`_engine.Connection` 
    428        to which this :class:`.Executable` is bound, or None if none found. 
    429 
    430        This is a traversal which checks locally, then 
    431        checks among the "from" clauses of associated objects 
    432        until a bound engine or connection is found. 
    433 
    434        """ 
    435        if self._bind is not None: 
    436            return self._bind 
    437 
    438        for f in _from_objects(self): 
    439            if f is self: 
    440                continue 
    441            engine = f.bind 
    442            if engine is not None: 
    443                return engine 
    444        else: 
    445            return None 
    446 
    447 
    448class SchemaEventTarget(object): 
    449    """Base class for elements that are the targets of :class:`.DDLEvents` 
    450    events. 
    451 
    452    This includes :class:`.SchemaItem` as well as :class:`.SchemaType`. 
    453 
    454    """ 
    455 
    456    def _set_parent(self, parent, **kw): 
    457        """Associate with this SchemaEvent's parent object.""" 
    458 
    459    def _set_parent_with_dispatch(self, parent, **kw): 
    460        self.dispatch.before_parent_attach(self, parent) 
    461        self._set_parent(parent, **kw) 
    462        self.dispatch.after_parent_attach(self, parent) 
    463 
    464 
    465class SchemaVisitor(ClauseVisitor): 
    466    """Define the visiting for ``SchemaItem`` objects.""" 
    467 
    468    __traverse_options__ = {"schema_visitor": True} 
    469 
    470 
    471class ColumnCollection(util.OrderedProperties): 
    472    """An ordered dictionary that stores a list of ColumnElement 
    473    instances. 
    474 
    475    Overrides the ``__eq__()`` method to produce SQL clauses between 
    476    sets of correlated columns. 
    477 
    478    """ 
    479 
    480    __slots__ = "_all_columns" 
    481 
    482    def __init__(self, *columns): 
    483        super(ColumnCollection, self).__init__() 
    484        object.__setattr__(self, "_all_columns", []) 
    485        for c in columns: 
    486            self.add(c) 
    487 
    488    def __str__(self): 
    489        return repr([str(c) for c in self]) 
    490 
    491    def replace(self, column): 
    492        """Add the given column to this collection, removing unaliased 
    493        versions of this column  as well as existing columns with the 
    494        same key. 
    495 
    496        E.g.:: 
    497 
    498             t = Table('sometable', metadata, Column('col1', Integer)) 
    499             t.columns.replace(Column('col1', Integer, key='columnone')) 
    500 
    501        will remove the original 'col1' from the collection, and add 
    502        the new column under the name 'columnname'. 
    503 
    504        Used by schema.Column to override columns during table reflection. 
    505 
    506        """ 
    507        remove_col = None 
    508        if column.name in self and column.key != column.name: 
    509            other = self[column.name] 
    510            if other.name == other.key: 
    511                remove_col = other 
    512                del self._data[other.key] 
    513 
    514        if column.key in self._data: 
    515            remove_col = self._data[column.key] 
    516 
    517        self._data[column.key] = column 
    518        if remove_col is not None: 
    519            self._all_columns[:] = [ 
    520                column if c is remove_col else c for c in self._all_columns 
    521            ] 
    522        else: 
    523            self._all_columns.append(column) 
    524 
    525    def add(self, column): 
    526        """Add a column to this collection. 
    527 
    528        The key attribute of the column will be used as the hash key 
    529        for this dictionary. 
    530 
    531        """ 
    532        if not column.key: 
    533            raise exc.ArgumentError( 
    534                "Can't add unnamed column to column collection" 
    535            ) 
    536        self[column.key] = column 
    537 
    538    def __delitem__(self, key): 
    539        raise NotImplementedError() 
    540 
    541    def __setattr__(self, key, obj): 
    542        raise NotImplementedError() 
    543 
    544    def __setitem__(self, key, value): 
    545        if key in self: 
    546 
    547            # this warning is primarily to catch select() statements 
    548            # which have conflicting column names in their exported 
    549            # columns collection 
    550 
    551            existing = self[key] 
    552 
    553            if existing is value: 
    554                return 
    555 
    556            if not existing.shares_lineage(value): 
    557                util.warn( 
    558                    "Column %r on table %r being replaced by " 
    559                    "%r, which has the same key.  Consider " 
    560                    "use_labels for select() statements." 
    561                    % (key, getattr(existing, "table", None), value) 
    562                ) 
    563 
    564            # pop out memoized proxy_set as this 
    565            # operation may very well be occurring 
    566            # in a _make_proxy operation 
    567            util.memoized_property.reset(value, "proxy_set") 
    568 
    569        self._all_columns.append(value) 
    570        self._data[key] = value 
    571 
    572    def clear(self): 
    573        raise NotImplementedError() 
    574 
    575    def remove(self, column): 
    576        del self._data[column.key] 
    577        self._all_columns[:] = [ 
    578            c for c in self._all_columns if c is not column 
    579        ] 
    580 
    581    def update(self, iter_): 
    582        cols = list(iter_) 
    583        all_col_set = set(self._all_columns) 
    584        self._all_columns.extend( 
    585            c for label, c in cols if c not in all_col_set 
    586        ) 
    587        self._data.update((label, c) for label, c in cols) 
    588 
    589    def extend(self, iter_): 
    590        cols = list(iter_) 
    591        all_col_set = set(self._all_columns) 
    592        self._all_columns.extend(c for c in cols if c not in all_col_set) 
    593        self._data.update((c.key, c) for c in cols) 
    594 
    595    __hash__ = None 
    596 
    597    @util.dependencies("sqlalchemy.sql.elements") 
    598    def __eq__(self, elements, other): 
    599        l = [] 
    600        for c in getattr(other, "_all_columns", other): 
    601            for local in self._all_columns: 
    602                if c.shares_lineage(local): 
    603                    l.append(c == local) 
    604        return elements.and_(*l) 
    605 
    606    def __contains__(self, other): 
    607        if not isinstance(other, util.string_types): 
    608            raise exc.ArgumentError("__contains__ requires a string argument") 
    609        return util.OrderedProperties.__contains__(self, other) 
    610 
    611    def __getstate__(self): 
    612        return {"_data": self._data, "_all_columns": self._all_columns} 
    613 
    614    def __setstate__(self, state): 
    615        object.__setattr__(self, "_data", state["_data"]) 
    616        object.__setattr__(self, "_all_columns", state["_all_columns"]) 
    617 
    618    def contains_column(self, col): 
    619        return col in set(self._all_columns) 
    620 
    621    def as_immutable(self): 
    622        return ImmutableColumnCollection(self._data, self._all_columns) 
    623 
    624 
    625class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection): 
    626    def __init__(self, data, all_columns): 
    627        util.ImmutableProperties.__init__(self, data) 
    628        object.__setattr__(self, "_all_columns", all_columns) 
    629 
    630    extend = remove = util.ImmutableProperties._immutable 
    631 
    632 
    633class ColumnSet(util.ordered_column_set): 
    634    def contains_column(self, col): 
    635        return col in self 
    636 
    637    def extend(self, cols): 
    638        for col in cols: 
    639            self.add(col) 
    640 
    641    def __add__(self, other): 
    642        return list(self) + list(other) 
    643 
    644    @util.dependencies("sqlalchemy.sql.elements") 
    645    def __eq__(self, elements, other): 
    646        l = [] 
    647        for c in other: 
    648            for local in self: 
    649                if c.shares_lineage(local): 
    650                    l.append(c == local) 
    651        return elements.and_(*l) 
    652 
    653    def __hash__(self): 
    654        return hash(tuple(x for x in self)) 
    655 
    656 
    657def _bind_or_error(schemaitem, msg=None): 
    658    bind = schemaitem.bind 
    659    if not bind: 
    660        name = schemaitem.__class__.__name__ 
    661        label = getattr( 
    662            schemaitem, "fullname", getattr(schemaitem, "name", None) 
    663        ) 
    664        if label: 
    665            item = "%s object %r" % (name, label) 
    666        else: 
    667            item = "%s object" % name 
    668        if msg is None: 
    669            msg = ( 
    670                "%s is not bound to an Engine or Connection.  " 
    671                "Execution can not proceed without a database to execute " 
    672                "against." % item 
    673            ) 
    674        raise exc.UnboundExecutionError(msg) 
    675    return bind