1#    Copyright 2017-2022 Alexey Stepanov aka penguinolog 
    2 
    3#    Licensed under the Apache License, Version 2.0 (the "License"); you may 
    4#    not use this file except in compliance with the License. You may obtain 
    5#    a copy of the License at 
    6 
    7#         http://www.apache.org/licenses/LICENSE-2.0 
    8 
    9#    Unless required by applicable law or agreed to in writing, software 
    10#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 
    11#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 
    12#    License for the specific language governing permissions and limitations 
    13#    under the License. 
    14 
    15"""JSONField implementation for SQLAlchemy.""" 
    16 
    17from __future__ import annotations 
    18 
    19import json 
    20import typing 
    21 
    22import sqlalchemy.ext.mutable 
    23import sqlalchemy.types 
    24 
    25if typing.TYPE_CHECKING: 
    26    import types 
    27 
    28    from sqlalchemy.engine import Dialect 
    29    from sqlalchemy.sql.type_api import TypeEngine 
    30 
    31__all__ = ("JSONField", "mutable_json_field") 
    32 
    33 
    34# noinspection PyAbstractClass 
    35class JSONField(sqlalchemy.types.TypeDecorator):  # type: ignore[type-arg]  # pylint: disable=abstract-method 
    36    """Represent an immutable structure as a json-encoded string or json. 
    37 
    38    Usage:: 
    39 
    40        JSONField(enforce_string=True|False, enforce_unicode=True|False) 
    41 
    42    """ 
    43 
    44    def process_literal_param(self, value: typing.Any, dialect: Dialect) -> typing.Any: 
    45        """Re-use of process_bind_param. 
    46 
    47        :return: encoded value if required 
    48        :rtype: typing.Union[str, typing.Any] 
    49        """ 
    50        return self.process_bind_param(value, dialect) 
    51 
    52    impl = sqlalchemy.types.TypeEngine  # Special placeholder 
    53    cache_ok = False  # Cache complexity due to requerement of value re-serialization and mutability 
    54 
    55    def __init__(  # pylint: disable=keyword-arg-before-vararg 
    56        self, 
    57        enforce_string: bool = False, 
    58        enforce_unicode: bool = False, 
    59        json: types.ModuleType | typing.Any = json,  # pylint: disable=redefined-outer-name 
    60        json_type: TypeEngine[typing.Any] | type[TypeEngine[typing.Any]] = sqlalchemy.JSON, 
    61        *args: typing.Any, 
    62        **kwargs: typing.Any, 
    63    ) -> None: 
    64        """JSONField. 
    65 
    66        :param enforce_string: enforce String(UnicodeText) type usage 
    67        :type enforce_string: bool 
    68        :param enforce_unicode: do not encode non-ascii data 
    69        :type enforce_unicode: bool 
    70        :param json: JSON encoding/decoding library. By default: standard json package. 
    71        :param json_type: the sqlalchemy/dialect class that will be used to render the DB JSON type. 
    72                          By default: sqlalchemy.JSON 
    73        :param args: extra baseclass arguments 
    74        :type args: typing.Any 
    75        :param kwargs: extra baseclass keyworded arguments 
    76        :type kwargs: typing.Any 
    77        """ 
    78        self.__enforce_string = enforce_string 
    79        self.__enforce_unicode = enforce_unicode 
    80        self.__json_codec = json 
    81        self.__json_type = json_type 
    82        super().__init__(*args, **kwargs) 
    83 
    84    def __use_json(self, dialect: Dialect) -> bool: 
    85        """Helper to determine, which encoder to use. 
    86 
    87        :return: use engine-based json encoder 
    88        :rtype: bool 
    89        """ 
    90        return hasattr(dialect, "_json_serializer") and not self.__enforce_string 
    91 
    92    def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[typing.Any]: 
    93        """Select impl by dialect. 
    94 
    95        :return: dialect implementation depends on decoding method 
    96        :rtype: TypeEngine 
    97        """ 
    98        # types are handled by DefaultDialect, Dialect class is abstract 
    99        if self.__use_json(dialect): 
    100            return dialect.type_descriptor(self.__json_type)  # type: ignore[arg-type] 
    101        return dialect.type_descriptor(sqlalchemy.UnicodeText)  # type: ignore[arg-type] 
    102 
    103    def process_bind_param(self, value: typing.Any, dialect: Dialect) -> str | typing.Any: 
    104        """Encode data, if required. 
    105 
    106        :return: encoded value if required 
    107        :rtype: typing.Union[str, typing.Any] 
    108        """ 
    109        if self.__use_json(dialect) or value is None: 
    110            return value 
    111 
    112        return self.__json_codec.dumps(value, ensure_ascii=not self.__enforce_unicode) 
    113 
    114    def process_result_value(self, value: str | typing.Any, dialect: Dialect) -> typing.Any: 
    115        """Decode data, if required. 
    116 
    117        :return: decoded result value if required 
    118        :rtype: typing.Any 
    119        """ 
    120        if self.__use_json(dialect) or value is None: 
    121            return value 
    122 
    123        return self.__json_codec.loads(value) 
    124 
    125 
    126def mutable_json_field(  # pylint: disable=keyword-arg-before-vararg, redefined-outer-name 
    127    enforce_string: bool = False, 
    128    enforce_unicode: bool = False, 
    129    json: types.ModuleType | typing.Any = json, 
    130    *args: typing.Any, 
    131    **kwargs: typing.Any, 
    132) -> JSONField: 
    133    """Mutable JSONField creator. 
    134 
    135    :param enforce_string: enforce String(UnicodeText) type usage 
    136    :type enforce_string: bool 
    137    :param enforce_unicode: do not encode non-ascii data 
    138    :type enforce_unicode: bool 
    139    :param json: JSON encoding/decoding library. 
    140                 By default: standard json package. 
    141    :param args: extra baseclass arguments 
    142    :type args: typing.Any 
    143    :param kwargs: extra baseclass keyworded arguments 
    144    :type kwargs: typing.Any 
    145    :return: Mutable JSONField via MutableDict.as_mutable 
    146    :rtype: JSONField 
    147    """ 
    148    return sqlalchemy.ext.mutable.MutableDict.as_mutable(  # type: ignore[return-value] 
    149        JSONField(  # type: ignore[misc] 
    150            enforce_string=enforce_string, 
    151            enforce_unicode=enforce_unicode, 
    152            json=json, 
    153            *args,  # noqa: B026 
    154            **kwargs, 
    155        ) 
    156    )