1# Copyright 2017-2022 Alexey Stepanov aka penguinolog
2
3# Licensed under the Apache License, Version 2.0 (the "License"); you may
4# not use this file except in compliance with the License. You may obtain
5# a copy of the License at
6
7# http://www.apache.org/licenses/LICENSE-2.0
8
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12# License for the specific language governing permissions and limitations
13# under the License.
14
15"""JSONField implementation for SQLAlchemy."""
16
17from __future__ import annotations
18
19import json
20import typing
21
22import sqlalchemy.ext.mutable
23import sqlalchemy.types
24
25if typing.TYPE_CHECKING:
26 import types
27
28 from sqlalchemy.engine import Dialect
29 from sqlalchemy.sql.type_api import TypeEngine
30
31__all__ = ("JSONField", "mutable_json_field")
32
33
34# noinspection PyAbstractClass
class JSONField(sqlalchemy.types.TypeDecorator):  # type: ignore[type-arg]  # pylint: disable=abstract-method
    """Represent an immutable structure as a json-encoded string or json.

    Depending on the dialect and on ``enforce_string`` the value is either handed over
    untouched to the configured JSON type, or encoded to / decoded from text by the
    configured JSON codec.

    Usage::

        JSONField(enforce_string=True|False, enforce_unicode=True|False)

    """

    impl = sqlalchemy.types.TypeEngine  # Placeholder only: the effective type is picked in load_dialect_impl
    cache_ok = False  # Caching is complicated by the requirement of value re-serialization and mutability

    def __init__(  # pylint: disable=keyword-arg-before-vararg
        self,
        enforce_string: bool = False,
        enforce_unicode: bool = False,
        json: types.ModuleType | typing.Any = json,  # pylint: disable=redefined-outer-name
        json_type: TypeEngine[typing.Any] | type[TypeEngine[typing.Any]] = sqlalchemy.JSON,
        *args: typing.Any,
        **kwargs: typing.Any,
    ) -> None:
        """Configure the field.

        :param enforce_string: always store as text (UnicodeText), even if the dialect looks JSON-capable
        :type enforce_string: bool
        :param enforce_unicode: keep non-ascii characters as is when encoding to text
        :type enforce_unicode: bool
        :param json: object providing ``dumps``/``loads`` for the text mode. By default: standard json package.
        :param json_type: the sqlalchemy/dialect type used to render the DB JSON type.
                          By default: sqlalchemy.JSON
        :param args: extra positional arguments for the base class
        :type args: typing.Any
        :param kwargs: extra keyword arguments for the base class
        :type kwargs: typing.Any
        """
        self.__json_type = json_type
        self.__json_codec = json
        self.__enforce_unicode = enforce_unicode
        self.__enforce_string = enforce_string
        super().__init__(*args, **kwargs)

    def __use_json(self, dialect: Dialect) -> bool:
        """Decide whether the configured JSON type should handle the value.

        :return: True if the dialect exposes ``_json_serializer`` and string mode is not enforced
        :rtype: bool
        """
        if not hasattr(dialect, "_json_serializer"):
            return False
        return not self.__enforce_string

    def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[typing.Any]:
        """Pick the implementation type for the dialect.

        :return: descriptor of the configured JSON type, or of UnicodeText for the text mode
        :rtype: TypeEngine
        """
        # Dialect itself is abstract: type_descriptor is provided by DefaultDialect.
        type_spec = self.__json_type if self.__use_json(dialect) else sqlalchemy.UnicodeText
        return dialect.type_descriptor(type_spec)  # type: ignore[arg-type]

    def process_bind_param(self, value: typing.Any, dialect: Dialect) -> str | typing.Any:
        """Prepare a value for sending to the database.

        :return: the value as is in JSON mode or for None, otherwise its JSON text
        :rtype: typing.Union[str, typing.Any]
        """
        if self.__use_json(dialect):
            return value
        if value is None:
            return None

        escape_non_ascii = not self.__enforce_unicode
        return self.__json_codec.dumps(value, ensure_ascii=escape_non_ascii)

    def process_literal_param(self, value: typing.Any, dialect: Dialect) -> typing.Any:
        """Prepare a literal value: delegates to process_bind_param.

        :return: encoded value if required
        :rtype: typing.Union[str, typing.Any]
        """
        return self.process_bind_param(value, dialect)

    def process_result_value(self, value: str | typing.Any, dialect: Dialect) -> typing.Any:
        """Prepare a value received from the database.

        :return: the value as is in JSON mode or for None, otherwise the decoded JSON text
        :rtype: typing.Any
        """
        if self.__use_json(dialect):
            return value
        if value is None:
            return None

        return self.__json_codec.loads(value)
124
125
def mutable_json_field(  # pylint: disable=keyword-arg-before-vararg, redefined-outer-name
    enforce_string: bool = False,
    enforce_unicode: bool = False,
    json: types.ModuleType | typing.Any = json,
    *args: typing.Any,
    **kwargs: typing.Any,
) -> JSONField:
    """Mutable JSONField creator.

    Builds a :class:`JSONField` and wraps it with ``MutableDict.as_mutable``.

    :param enforce_string: enforce String(UnicodeText) type usage
    :type enforce_string: bool
    :param enforce_unicode: do not encode non-ascii data
    :type enforce_unicode: bool
    :param json: JSON encoding/decoding library.
                 By default: standard json package.
    :param args: extra positional arguments, forwarded to JSONField right after ``json``
                 (so the first one is received by JSONField as ``json_type``)
    :type args: typing.Any
    :param kwargs: extra keyword arguments, forwarded to JSONField
    :type kwargs: typing.Any
    :return: Mutable JSONField via MutableDict.as_mutable
    :rtype: JSONField
    """
    # The leading options are forwarded positionally on purpose: in a call that mixes
    # keyword arguments with ``*args``, the unpacked positionals are bound first, so any
    # extra positional argument would collide with ``enforce_string=...`` and raise
    # "TypeError: got multiple values for argument".
    field = JSONField(enforce_string, enforce_unicode, json, *args, **kwargs)
    return sqlalchemy.ext.mutable.MutableDict.as_mutable(field)  # type: ignore[return-value]