Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/serialization/serde.py: 40%
163 statements
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
« prev ^ index » next coverage.py v7.0.1, created at 2022-12-25 06:11 +0000
1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
20import dataclasses
21import enum
22import logging
23import re
24import sys
25from importlib import import_module
26from types import ModuleType
27from typing import Any, TypeVar, Union
29import attr
31import airflow.serialization.serializers
32from airflow.configuration import conf
33from airflow.utils.module_loading import import_string, iter_namespace, qualname
log = logging.getLogger(__name__)

# serialize() raises RecursionError once its depth reaches this value
MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1

# keys of the envelope produced by encode() and read back by decode()
CLASSNAME = "__classname__"
VERSION = "__version__"
DATA = "__data__"
# SCHEMA_ID is rejected by serialize() when found in a plain dict
SCHEMA_ID = "__id__"
CACHE = "__cache__"

# keys of the legacy layout translated by _convert()
OLD_TYPE = "__type"
OLD_SOURCE = "__source"
OLD_DATA = "__var"

# version used when a class has no ``__version__`` and for converted legacy data
DEFAULT_VERSION = 0

# type aliases describing serialized (primitive) values
T = TypeVar("T", bool, float, int, dict, list, str, tuple, set)
U = Union[bool, float, int, dict, list, str, tuple, set]
S = Union[list, tuple, set]

# registries filled by _register(): qualified class name -> implementing module
_serializers: dict[str, ModuleType] = {}
_deserializers: dict[str, ModuleType] = {}
# class names allowed for deserialization on top of the configured patterns
_extra_allowed: set[str] = set()

_primitives = (int, bool, float, str)
_iterables = (list, set, tuple)
# compiled allowed_deserialization_classes patterns, filled by _compile_patterns()
_patterns: list[re.Pattern] = []

_reverse_cache: dict[int, tuple[ModuleType, str, int]] = {}
def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]:
    """
    Wrap ``data`` in the envelope that :func:`deserialize` understands.

    :param cls: qualified class name, stored under ``__classname__``
    :param version: serialization version, stored under ``__version__``
    :param data: payload, stored under ``__data__``
    :return: the envelope dict
    """
    envelope: dict[str, str | int | T] = {CLASSNAME: cls}
    envelope[VERSION] = version
    envelope[DATA] = data
    return envelope
def decode(d: dict[str, str | int | T]) -> tuple:
    """
    Split an envelope (as produced by :func:`encode`) into its parts.

    :param d: envelope dict
    :return: ``(classname, version, data)``; ``data`` is None when absent
    :raises KeyError: when the class name or version key is missing
    """
    classname = d[CLASSNAME]
    version = d[VERSION]
    data = d.get(DATA)
    return classname, version, data
def serialize(o: object, depth: int = 0) -> U | None:
    """
    Recursively serialize ``o`` into a primitive or a structure of primitives.

    Objects are handled in this order:

    * ``None`` and primitives (int, bool, float, str) are returned as is; enum
      members that are also primitives are replaced by their ``value``.
    * lists, sets and tuples are serialized item by item and keep their
      container type.
    * dicts are serialized value by value with keys coerced to ``str``; a dict
      holding a reserved key (``__classname__`` or ``__id__``) is rejected.
    * anything else is wrapped in a ``{__classname__, __version__, __data__}``
      envelope using, in order: 1) a registered serializer from the
      ``airflow.serialization.serializers`` namespace, 2) the object's own
      ``serialize()`` method, 3) dataclass fields, 4) attrs fields.

    :param o: object to serialize
    :param depth: private, current recursion depth
    :return: a primitive
    :raises RecursionError: when the maximum recursion depth is reached
    :raises AttributeError: when a dict contains a reserved key
    :raises TypeError: when no way to serialize ``o`` is found
    """
    if depth == MAX_RECURSION_DEPTH:
        raise RecursionError("maximum recursion depth reached for serialization")

    if o is None:
        return None

    if isinstance(o, _primitives):
        # enum members subclassing int/str pass the primitive check; keep their value
        return o.value if isinstance(o, enum.Enum) else o

    if isinstance(o, _iterables):
        items = [serialize(item, depth + 1) for item in o]
        if isinstance(o, tuple):
            return tuple(items)
        return set(items) if isinstance(o, set) else items

    if isinstance(o, dict):
        if any(key in o for key in (CLASSNAME, SCHEMA_ID)):
            raise AttributeError(f"reserved key {CLASSNAME} or {SCHEMA_ID} found in dict to serialize")
        return {str(key): serialize(value, depth + 1) for key, value in o.items()}

    klass = type(o)
    name = qualname(o)
    envelope = {
        CLASSNAME: name,
        VERSION: getattr(klass, "__version__", DEFAULT_VERSION),
    }

    # 1) builtin serializer registered for this qualified name
    module = _serializers.get(name)
    if module is not None:
        data, encoded_name, encoded_version, done = module.serialize(o)
        if done:
            return encode(encoded_name, encoded_version, serialize(data, depth + 1))

    # 2) the object knows how to serialize itself
    if hasattr(o, "serialize"):
        payload = getattr(o, "serialize")()
        # only dict results are recursed into; anything else is stored verbatim
        if isinstance(payload, dict):
            payload = serialize(payload, depth + 1)
        envelope[DATA] = payload
        return envelope

    # 3) dataclasses and 4) attrs classes: serialize their fields
    if dataclasses.is_dataclass(klass):
        payload = dataclasses.asdict(o)
    elif attr.has(klass):
        # keep only attributes accepted by __init__ so deserialize() can call cls(**data)
        payload = attr.asdict(o, recurse=True, filter=lambda a, v: a.init)  # type: ignore[arg-type]
    else:
        raise TypeError(f"cannot serialize object of type {klass}")

    envelope[DATA] = serialize(payload, depth + 1)
    return envelope
def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
    """
    Deserializes an object of primitive type T into an object. Uses an allow
    list to determine if a class can be loaded.

    :param o: primitive to deserialize into an arbitrary object.
    :param full: if False it will return a stringified representation
                 of an object and will not load any classes. This also
                 applies to objects nested inside lists, tuples and sets.
    :param type_hint: if set it will be used to help determine what
                 object to deserialize in. It does not override if another
                 specification is found
    :return: object
    :raises TypeError: when ``o`` is not a primitive, iterable or dict, when the
        serialized version is newer than the class version, or when no
        deserializer is available for the class
    :raises ImportError: when the class is not in the deserialization allow list
    """
    if o is None:
        return o

    if isinstance(o, _primitives):
        return o

    # lists, tuples and sets all come back as lists; ``full`` must be propagated
    # so nested encoded objects also honour a request not to load classes
    if isinstance(o, _iterables):
        return [deserialize(d, full) for d in o]

    if not isinstance(o, dict):
        raise TypeError(f"cannot deserialize object of type {type(o)}")

    o = _convert(o)

    # plain dict (no class name and no type hint), or no version at all
    if (CLASSNAME not in o and not type_hint) or VERSION not in o:
        return {str(k): deserialize(v, full) for k, v in o.items()}

    # custom deserialization starts here
    cls: Any
    version = 0
    value: Any
    classname: str

    if type_hint:
        cls = type_hint
        classname = qualname(cls)
        version = 0  # type hinting always sets version to 0
        value = o

    # an explicit class name in the data takes precedence over the type hint
    if CLASSNAME in o and VERSION in o:
        classname, version, value = decode(o)
        if not _match(classname) and classname not in _extra_allowed:
            raise ImportError(
                f"{classname} was not found in allow list for deserialization imports. "
                "To allow it, add it to allowed_deserialization_classes in the configuration"
            )

        if full:
            cls = import_string(classname)

    # only return string representation
    if not full:
        return _stringify(classname, version, value)

    # registered deserializer
    if classname in _deserializers:
        return _deserializers[classname].deserialize(classname, version, deserialize(value))

    # class has deserialization function
    if hasattr(cls, "deserialize"):
        return getattr(cls, "deserialize")(deserialize(value), version)

    # attr or dataclass
    if attr.has(cls) or dataclasses.is_dataclass(cls):
        class_version = getattr(cls, "__version__", 0)
        if int(version) > class_version:
            # exceptions do not %-format their args, so build the message up front
            raise TypeError(
                f"serialized version of {classname} is newer than module version "
                f"({version} > {class_version})"
            )

        return cls(**deserialize(value))

    # no deserializer available
    raise TypeError(f"No deserializer found for {classname}")
def _convert(old: dict) -> dict:
    """
    Translate a dict in the legacy ``__type``/``__var`` layout to the current envelope.

    Dicts that do not carry both legacy keys are returned untouched.
    """
    if OLD_TYPE not in old or OLD_DATA not in old:
        return old

    return {
        CLASSNAME: old[OLD_TYPE],
        VERSION: DEFAULT_VERSION,
        # the payload is read from a second ``__var`` nested inside the first
        DATA: old[OLD_DATA][OLD_DATA],
    }
def _match(classname: str) -> bool:
    """Return True when any compiled allow-list pattern matches the start of ``classname``."""
    return any(pattern.match(classname) for pattern in _patterns)
def _stringify(classname: str, version: int, value: T | None) -> str:
    """
    Render serialized data as ``classname@version=N(...)`` without loading the class.

    Primitive payloads are rendered as is, iterables as comma-separated items and
    dicts as comma-separated ``key=value`` pairs. Nested items are rendered through
    ``deserialize(..., full=False)``, so nested encoded objects are stringified
    as well rather than rejected. Any other payload (e.g. None) yields empty
    parentheses.

    :param classname: qualified name of the serialized class
    :param version: serialized version of the class
    :param value: the serialized payload
    :return: human-readable representation
    """
    s = f"{classname}@version={version}("
    if isinstance(value, _primitives):
        s += f"{value}"
    elif isinstance(value, _iterables):
        # join the rendered items, not the characters of one rendered string
        s += ",".join(str(deserialize(item, full=False)) for item in value)
    elif isinstance(value, dict):
        s += ",".join(f"{k}={str(deserialize(v, full=False))}" for k, v in value.items())
    # always close the parenthesis, also for empty dicts and unsupported payloads
    return s + ")"
def _register():
    """
    Register builtin serializers and deserializers for types that don't have any themselves.

    Every module in the ``airflow.serialization.serializers`` namespace may expose
    ``serializers`` and/or ``deserializers`` sequences of classes or qualified class
    names. Each entry is mapped to that module, whose ``serialize`` / ``deserialize``
    functions are then used by :func:`serialize` / :func:`deserialize`. Names
    registered for deserialization are also added to the extra allow list.

    :raises AttributeError: when two different modules register the same class
    """
    _serializers.clear()
    _deserializers.clear()

    for _, name, _ in iter_namespace(airflow.serialization.serializers):
        module = import_module(name)

        for s in getattr(module, "serializers", ()):
            if not isinstance(s, str):
                s = qualname(s)
            if s in _serializers and _serializers[s] != module:
                raise AttributeError(f"duplicate {s} for serialization in {module} and {_serializers[s]}")
            log.debug("registering %s for serialization", s)
            _serializers[s] = module

        for d in getattr(module, "deserializers", ()):
            if not isinstance(d, str):
                d = qualname(d)
            if d in _deserializers and _deserializers[d] != module:
                # report the module already registered for *deserialization*
                raise AttributeError(f"duplicate {d} for deserialization in {module} and {_deserializers[d]}")
            log.debug("registering %s for deserialization", d)
            _deserializers[d] = module
            _extra_allowed.add(d)
def _compile_patterns():
    """
    (Re)build the allow-list regexes from ``[core] allowed_deserialization_classes``.

    The option value is split on whitespace and each entry is compiled as a regex;
    patterns from a previous call are discarded first.
    """
    entries = conf.get("core", "allowed_deserialization_classes").split()

    _patterns.clear()  # reinitialize instead of accumulating
    _patterns.extend(map(re.compile, entries))
# populate the (de)serializer registries, then the allow-list patterns, at import time
_register()
_compile_patterns()