Coverage for /pythoncovmergedfiles/medio/medio/src/airflow/airflow/serialization/serde.py: 38%
185 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-06-07 06:35 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import dataclasses
21import enum
22import functools
23import logging
24import re
25import sys
26from importlib import import_module
27from types import ModuleType
28from typing import Any, TypeVar, Union, cast
30import attr
32import airflow.serialization.serializers
33from airflow.configuration import conf
34from airflow.stats import Stats
35from airflow.utils.module_loading import import_string, iter_namespace, qualname
log = logging.getLogger(__name__)

# Leave one frame of headroom below the interpreter's recursion limit so the
# serializer raises its own RecursionError before the interpreter does.
MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1

# Keys of the envelope produced by ``encode`` and consumed by ``decode``.
CLASSNAME = "__classname__"
VERSION = "__version__"
DATA = "__data__"
SCHEMA_ID = "__id__"
CACHE = "__cache__"

# Keys used by the legacy serialization format (converted by ``_convert``).
OLD_TYPE = "__type"
OLD_SOURCE = "__source"
OLD_DATA = "__var"

# Version assigned to classes that do not declare ``__version__``.
DEFAULT_VERSION = 0

T = TypeVar("T", bool, float, int, dict, list, str, tuple, set)
U = Union[bool, float, int, dict, list, str, tuple, set]
S = Union[list, tuple, set]

# Registries populated by ``_register``: qualified class name -> module that
# provides serialize/deserialize/stringify support for that class.
_serializers: dict[str, ModuleType] = {}
_deserializers: dict[str, ModuleType] = {}
_stringifiers: dict[str, ModuleType] = {}
# Class names registered by deserializer modules; implicitly allow-listed.
_extra_allowed: set[str] = set()

_primitives = (int, bool, float, str)
_builtin_collections = (frozenset, list, set, tuple)  # dict is treated specially.
def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]:
    """Wrap *data* in the envelope dict understood by ``decode``.

    :param cls: qualified class name of the serialized object.
    :param version: serialization version of that class.
    :param data: the already-serialized payload.
    :return: dict keyed by ``CLASSNAME``, ``VERSION`` and ``DATA``.
    """
    envelope: dict[str, str | int | T] = {
        CLASSNAME: cls,
        VERSION: version,
        DATA: data,
    }
    return envelope
def decode(d: dict[str, Any]) -> tuple[str, int, Any]:
    """Unpack an encoded envelope into ``(classname, version, data)``.

    :param d: dict previously produced by ``encode``.
    :raise ValueError: if the classname is not a str or the version not an int.
    """
    classname, version = d[CLASSNAME], d[VERSION]

    if not (isinstance(classname, str) and isinstance(version, int)):
        raise ValueError(f"cannot decode {d!r}")

    # DATA is optional: a missing payload decodes to None.
    return classname, version, d.get(DATA)
def serialize(o: object, depth: int = 0) -> U | None:
    """Serialize an object into a representation consisting only built-in types.

    Primitives (int, float, bool, str) are returned as-is. Built-in collections
    are iterated over, where it is assumed that keys in a dict can be represented
    as str.

    Values that are not of a built-in type are serialized if a serializer is
    found for them. The order in which serializers are used is

    1. A ``serialize`` function provided by the object.
    2. A registered serializer in the namespace of ``airflow.serialization.serializers``
    3. Annotations from attr or dataclass.

    Limitations: attr and dataclass objects can lose type information for nested objects
    as they do not store this when calling ``asdict``. This means that at deserialization values
    will be deserialized as a dict as opposed to reinstating the object. Provide
    your own serializer to work around this.

    :param o: The object to serialize.
    :param depth: Private tracker for nested serialization.
    :raise TypeError: A serializer cannot be found.
    :raise RecursionError: The object is too nested for the function to handle.
    :return: A representation of ``o`` that consists of only built-in types.
    """
    if depth == MAX_RECURSION_DEPTH:
        raise RecursionError("maximum recursion depth reached for serialization")

    # None remains None
    if o is None:
        return o

    # primitive types are returned as is
    if isinstance(o, _primitives):
        # Only enums that mix in a primitive (e.g. IntEnum/str-Enum) reach this
        # branch; store their raw value rather than the enum instance.
        if isinstance(o, enum.Enum):
            return o.value

        return o

    if isinstance(o, list):
        return [serialize(d, depth + 1) for d in o]

    if isinstance(o, dict):
        # The envelope keys are reserved: a plain dict containing them would be
        # indistinguishable from an encoded object at deserialization time.
        if CLASSNAME in o or SCHEMA_ID in o:
            raise AttributeError(f"reserved key {CLASSNAME} or {SCHEMA_ID} found in dict to serialize")

        return {str(k): serialize(v, depth + 1) for k, v in o.items()}

    cls = type(o)
    qn = qualname(o)

    # Envelope for the custom serializers below.
    dct = {
        CLASSNAME: qn,
        VERSION: getattr(cls, "__version__", DEFAULT_VERSION),
    }

    # if there is a builtin serializer available use that
    if qn in _serializers:
        data, classname, version, is_serialized = _serializers[qn].serialize(o)
        if is_serialized:
            return encode(classname, version, serialize(data, depth + 1))

    # object / class brings their own
    if hasattr(o, "serialize"):
        data = getattr(o, "serialize")()

        # if we end up with a structure, ensure its values are serialized
        if isinstance(data, dict):
            data = serialize(data, depth + 1)
        dct[DATA] = data
        return dct

    # pydantic models are recursive
    if _is_pydantic(cls):
        data = o.dict()  # type: ignore[attr-defined]
        dct[DATA] = serialize(data, depth + 1)
        return dct

    # dataclasses
    if dataclasses.is_dataclass(cls):
        # fixme: unfortunately using asdict with nested dataclasses it looses information
        data = dataclasses.asdict(o)  # type: ignore[call-overload]
        dct[DATA] = serialize(data, depth + 1)
        return dct

    # attr annotated
    if attr.has(cls):
        # Only include attributes which we can pass back to the classes constructor
        data = attr.asdict(cast(attr.AttrsInstance, o), recurse=True, filter=lambda a, v: a.init)
        dct[DATA] = serialize(data, depth + 1)
        return dct

    raise TypeError(f"cannot serialize object of type {cls}")
def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
    """
    Deserializes an object of primitive type T into an object. Uses an allow
    list to determine if a class can be loaded.

    :param o: primitive to deserialize into an arbitrary object.
    :param full: if False it will return a stringified representation
        of an object and will not load any classes
    :param type_hint: if set it will be used to help determine what
        object to deserialize in. It does not override if another
        specification is found
    :raise TypeError: if no deserializer can be found, or the serialized
        version is newer than the installed class version.
    :raise ImportError: if the class is not in the deserialization allow list.
    :return: object
    """
    if o is None:
        return o

    if isinstance(o, _primitives):
        return o

    # tuples, sets are included here for backwards compatibility
    if isinstance(o, _builtin_collections):
        col = [deserialize(d) for d in o]
        if isinstance(o, tuple):
            return tuple(col)

        if isinstance(o, set):
            return set(col)

        return col

    if not isinstance(o, dict):
        # if o is not a dict, then it's already deserialized
        # in this case we should return it as is
        return o

    # Upgrade any legacy-format payload to the current envelope layout first.
    o = _convert(o)

    # plain dict and no type hint
    if CLASSNAME not in o and not type_hint or VERSION not in o:
        return {str(k): deserialize(v, full) for k, v in o.items()}

    # custom deserialization starts here
    cls: Any
    version = 0
    value: Any = None
    classname = ""

    if type_hint:
        cls = type_hint
        classname = qualname(cls)
        version = 0  # type hinting always sets version to 0
        value = o

    # An explicit envelope in the data wins over the type hint.
    if CLASSNAME in o and VERSION in o:
        classname, version, value = decode(o)

    if not classname:
        raise TypeError("classname cannot be empty")

    # only return string representation
    if not full:
        return _stringify(classname, version, value)

    if not _match(classname) and classname not in _extra_allowed:
        raise ImportError(
            f"{classname} was not found in allow list for deserialization imports. "
            f"To allow it, add it to allowed_deserialization_classes in the configuration"
        )

    cls = import_string(classname)

    # registered deserializer
    if classname in _deserializers:
        return _deserializers[classname].deserialize(classname, version, deserialize(value))

    # class has deserialization function
    if hasattr(cls, "deserialize"):
        return getattr(cls, "deserialize")(deserialize(value), version)

    # attr or dataclass or pydantic
    if attr.has(cls) or dataclasses.is_dataclass(cls) or _is_pydantic(cls):
        class_version = getattr(cls, "__version__", 0)
        if int(version) > class_version:
            # Bug fix: this previously passed %s-style arguments straight to
            # TypeError, which never formats them; use an f-string instead.
            raise TypeError(
                f"serialized version of {classname} is newer than module version "
                f"({version} > {class_version})"
            )

        return cls(**deserialize(value))

    # no deserializer available
    raise TypeError(f"No deserializer found for {classname}")
def _convert(old: dict) -> dict:
    """Converts an old style serialization to new style.

    Anything that does not carry both legacy keys is returned untouched.
    """
    if OLD_TYPE not in old or OLD_DATA not in old:
        return old

    # Legacy payloads nest the actual data one level deeper under OLD_DATA.
    return {
        CLASSNAME: old[OLD_TYPE],
        VERSION: DEFAULT_VERSION,
        DATA: old[OLD_DATA][OLD_DATA],
    }
def _match(classname: str) -> bool:
    """Return True if *classname* matches any configured allow-list pattern."""
    for pattern in _get_patterns():
        if pattern.match(classname) is not None:
            return True
    return False
def _stringify(classname: str, version: int, value: T | None) -> str:
    """Convert a previously serialized object in a somewhat human-readable format.

    This function is not designed to be exact, and will not extensively traverse
    the whole tree of an object.

    :param classname: qualified class name of the serialized object.
    :param version: serialization version of that class.
    :param value: the serialized payload (primitive, collection or dict).
    :return: a ``classname@version=N(...)`` style string.
    """
    # A registered stringifier takes precedence over the generic rendering.
    if classname in _stringifiers:
        return _stringifiers[classname].stringify(classname, version, value)

    s = f"{classname}@version={version}("
    if isinstance(value, _primitives):
        s += f"{value})"
    elif isinstance(value, _builtin_collections):
        # deserialized values can be != str
        s += ",".join(str(deserialize(value, full=False)))
    elif isinstance(value, dict):
        # Bug fix: the previous implementation appended "k=v," per item and then
        # sliced off the trailing comma with s[:-1]; for an empty dict that
        # sliced off the opening parenthesis instead. Joining avoids both.
        s += ",".join(f"{k}={deserialize(v, full=False)}" for k, v in value.items())
        s += ")"

    return s
311def _is_pydantic(cls: Any) -> bool:
312 """Return True if the class is a pydantic model.
314 Checking is done by attributes as it is significantly faster than
315 using isinstance.
316 """
317 return hasattr(cls, "__validators__") and hasattr(cls, "__fields__") and hasattr(cls, "dict")
def _register():
    """Register builtin serializers and deserializers for types that don't have any themselves.

    Scans the ``airflow.serialization.serializers`` namespace; each module there
    may export ``serializers``, ``deserializers`` and ``stringifiers`` lists of
    classes (or qualified class names) it handles.

    :raise AttributeError: if two different modules claim the same class.
    """
    _serializers.clear()
    _deserializers.clear()
    _stringifiers.clear()

    with Stats.timer("serde.load_serializers") as timer:
        for _, name, _ in iter_namespace(airflow.serialization.serializers):
            name = import_module(name)
            for s in getattr(name, "serializers", ()):
                if not isinstance(s, str):
                    s = qualname(s)
                if s in _serializers and _serializers[s] != name:
                    raise AttributeError(f"duplicate {s} for serialization in {name} and {_serializers[s]}")
                log.debug("registering %s for serialization", s)
                _serializers[s] = name
            for d in getattr(name, "deserializers", ()):
                if not isinstance(d, str):
                    d = qualname(d)
                if d in _deserializers and _deserializers[d] != name:
                    # Bug fix: the error message previously looked up
                    # _serializers[d] when reporting the conflicting module.
                    raise AttributeError(f"duplicate {d} for deserialization in {name} and {_deserializers[d]}")
                log.debug("registering %s for deserialization", d)
                _deserializers[d] = name
                # Registered deserializers are implicitly allow-listed.
                _extra_allowed.add(d)
            for c in getattr(name, "stringifiers", ()):
                if not isinstance(c, str):
                    c = qualname(c)
                # Bug fix: duplicate detection previously consulted the
                # _deserializers registry instead of _stringifiers.
                if c in _stringifiers and _stringifiers[c] != name:
                    raise AttributeError(f"duplicate {c} for stringifiers in {name} and {_stringifiers[c]}")
                log.debug("registering %s for stringifying", c)
                _stringifiers[c] = name

    log.debug("loading serializers took %.3f seconds", timer.duration)
@functools.lru_cache(maxsize=None)
def _get_patterns() -> list[re.Pattern]:
    # Compile the allow-list once; the config value is whitespace-separated.
    patterns = conf.get("core", "allowed_deserialization_classes").split()
    # NOTE(review): the substitution rewrites "a.b" into the regex "a\..b" — a
    # literal dot followed by a wildcard "." — presumably to escape dots while
    # keeping patterns permissive; confirm this matches the intended syntax.
    return [re.compile(re.sub(r"(\w)\.", r"\1\..", p)) for p in patterns]
361_register()