1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10# http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18from __future__ import annotations
19
20import dataclasses
21import enum
22import functools
23import logging
24import sys
25from fnmatch import fnmatch
26from importlib import import_module
27from typing import TYPE_CHECKING, Any, Pattern, TypeVar, Union, cast
28
29import attr
30import re2
31
32import airflow.serialization.serializers
33from airflow.configuration import conf
34from airflow.stats import Stats
35from airflow.utils.module_loading import import_string, iter_namespace, qualname
36
37if TYPE_CHECKING:
38 from types import ModuleType
39
40log = logging.getLogger(__name__)
41
42MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1
43
44CLASSNAME = "__classname__"
45VERSION = "__version__"
46DATA = "__data__"
47SCHEMA_ID = "__id__"
48CACHE = "__cache__"
49
50OLD_TYPE = "__type"
51OLD_SOURCE = "__source"
52OLD_DATA = "__var"
53OLD_DICT = "dict"
54
55DEFAULT_VERSION = 0
56
57T = TypeVar("T", bool, float, int, dict, list, str, tuple, set)
58U = Union[bool, float, int, dict, list, str, tuple, set]
59S = Union[list, tuple, set]
60
61_serializers: dict[str, ModuleType] = {}
62_deserializers: dict[str, ModuleType] = {}
63_stringifiers: dict[str, ModuleType] = {}
64_extra_allowed: set[str] = set()
65
66_primitives = (int, bool, float, str)
67_builtin_collections = (frozenset, list, set, tuple) # dict is treated specially.
68
69
70def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]:
71 """Encode an object so it can be understood by the deserializer."""
72 return {CLASSNAME: cls, VERSION: version, DATA: data}
73
74
75def decode(d: dict[str, Any]) -> tuple[str, int, Any]:
76 classname = d[CLASSNAME]
77 version = d[VERSION]
78
79 if not isinstance(classname, str) or not isinstance(version, int):
80 raise ValueError(f"cannot decode {d!r}")
81
82 data = d.get(DATA)
83
84 return classname, version, data
85
86
87def serialize(o: object, depth: int = 0) -> U | None:
88 """Serialize an object into a representation consisting only built-in types.
89
90 Primitives (int, float, bool, str) are returned as-is. Built-in collections
91 are iterated over, where it is assumed that keys in a dict can be represented
92 as str.
93
94 Values that are not of a built-in type are serialized if a serializer is
95 found for them. The order in which serializers are used is
96
97 1. A ``serialize`` function provided by the object.
98 2. A registered serializer in the namespace of ``airflow.serialization.serializers``
99 3. Annotations from attr or dataclass.
100
101 Limitations: attr and dataclass objects can lose type information for nested objects
102 as they do not store this when calling ``asdict``. This means that at deserialization values
103 will be deserialized as a dict as opposed to reinstating the object. Provide
104 your own serializer to work around this.
105
106 :param o: The object to serialize.
107 :param depth: Private tracker for nested serialization.
108 :raise TypeError: A serializer cannot be found.
109 :raise RecursionError: The object is too nested for the function to handle.
110 :return: A representation of ``o`` that consists of only built-in types.
111 """
112 if depth == MAX_RECURSION_DEPTH:
113 raise RecursionError("maximum recursion depth reached for serialization")
114
115 # None remains None
116 if o is None:
117 return o
118
119 # primitive types are returned as is
120 if isinstance(o, _primitives):
121 if isinstance(o, enum.Enum):
122 return o.value
123
124 return o
125
126 if isinstance(o, list):
127 return [serialize(d, depth + 1) for d in o]
128
129 if isinstance(o, dict):
130 if CLASSNAME in o or SCHEMA_ID in o:
131 raise AttributeError(f"reserved key {CLASSNAME} or {SCHEMA_ID} found in dict to serialize")
132
133 return {str(k): serialize(v, depth + 1) for k, v in o.items()}
134
135 cls = type(o)
136 qn = qualname(o)
137 classname = None
138
139 # Serialize namedtuple like tuples
140 # We also override the classname returned by the builtin.py serializer. The classname
141 # has to be "builtins.tuple", so that the deserializer can deserialize the object into tuple.
142 if _is_namedtuple(o):
143 qn = "builtins.tuple"
144 classname = qn
145
146 # if there is a builtin serializer available use that
147 if qn in _serializers:
148 data, serialized_classname, version, is_serialized = _serializers[qn].serialize(o)
149 if is_serialized:
150 return encode(classname or serialized_classname, version, serialize(data, depth + 1))
151
152 # custom serializers
153 dct = {
154 CLASSNAME: qn,
155 VERSION: getattr(cls, "__version__", DEFAULT_VERSION),
156 }
157
158 # object / class brings their own
159 if hasattr(o, "serialize"):
160 data = getattr(o, "serialize")()
161
162 # if we end up with a structure, ensure its values are serialized
163 if isinstance(data, dict):
164 data = serialize(data, depth + 1)
165
166 dct[DATA] = data
167 return dct
168
169 # pydantic models are recursive
170 if _is_pydantic(cls):
171 data = o.model_dump() # type: ignore[attr-defined]
172 dct[DATA] = serialize(data, depth + 1)
173 return dct
174
175 # dataclasses
176 if dataclasses.is_dataclass(cls):
177 # fixme: unfortunately using asdict with nested dataclasses it looses information
178 data = dataclasses.asdict(o) # type: ignore[call-overload]
179 dct[DATA] = serialize(data, depth + 1)
180 return dct
181
182 # attr annotated
183 if attr.has(cls):
184 # Only include attributes which we can pass back to the classes constructor
185 data = attr.asdict(cast(attr.AttrsInstance, o), recurse=False, filter=lambda a, v: a.init)
186 dct[DATA] = serialize(data, depth + 1)
187 return dct
188
189 raise TypeError(f"cannot serialize object of type {cls}")
190
191
192def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
193 """
194 Deserialize an object of primitive type and uses an allow list to determine if a class can be loaded.
195
196 :param o: primitive to deserialize into an arbitrary object.
197 :param full: if False it will return a stringified representation
198 of an object and will not load any classes
199 :param type_hint: if set it will be used to help determine what
200 object to deserialize in. It does not override if another
201 specification is found
202 :return: object
203 """
204 if o is None:
205 return o
206
207 if isinstance(o, _primitives):
208 return o
209
210 # tuples, sets are included here for backwards compatibility
211 if isinstance(o, _builtin_collections):
212 col = [deserialize(d) for d in o]
213 if isinstance(o, tuple):
214 return tuple(col)
215
216 if isinstance(o, set):
217 return set(col)
218
219 return col
220
221 if not isinstance(o, dict):
222 # if o is not a dict, then it's already deserialized
223 # in this case we should return it as is
224 return o
225
226 o = _convert(o)
227
228 # plain dict and no type hint
229 if CLASSNAME not in o and not type_hint or VERSION not in o:
230 return {str(k): deserialize(v, full) for k, v in o.items()}
231
232 # custom deserialization starts here
233 cls: Any
234 version = 0
235 value: Any = None
236 classname = ""
237
238 if type_hint:
239 cls = type_hint
240 classname = qualname(cls)
241 version = 0 # type hinting always sets version to 0
242 value = o
243
244 if CLASSNAME in o and VERSION in o:
245 classname, version, value = decode(o)
246
247 if not classname:
248 raise TypeError("classname cannot be empty")
249
250 # only return string representation
251 if not full:
252 return _stringify(classname, version, value)
253 if not _match(classname) and classname not in _extra_allowed:
254 raise ImportError(
255 f"{classname} was not found in allow list for deserialization imports. "
256 f"To allow it, add it to allowed_deserialization_classes in the configuration"
257 )
258
259 cls = import_string(classname)
260
261 # registered deserializer
262 if classname in _deserializers:
263 return _deserializers[classname].deserialize(classname, version, deserialize(value))
264
265 # class has deserialization function
266 if hasattr(cls, "deserialize"):
267 return getattr(cls, "deserialize")(deserialize(value), version)
268
269 # attr or dataclass or pydantic
270 if attr.has(cls) or dataclasses.is_dataclass(cls) or _is_pydantic(cls):
271 class_version = getattr(cls, "__version__", 0)
272 if int(version) > class_version:
273 raise TypeError(
274 "serialized version of %s is newer than module version (%s > %s)",
275 classname,
276 version,
277 class_version,
278 )
279
280 return cls(**deserialize(value))
281
282 # no deserializer available
283 raise TypeError(f"No deserializer found for {classname}")
284
285
286def _convert(old: dict) -> dict:
287 """Convert an old style serialization to new style."""
288 if OLD_TYPE in old and OLD_DATA in old:
289 # Return old style dicts directly as they do not need wrapping
290 if old[OLD_TYPE] == OLD_DICT:
291 return old[OLD_DATA]
292 else:
293 return {CLASSNAME: old[OLD_TYPE], VERSION: DEFAULT_VERSION, DATA: old[OLD_DATA]}
294
295 return old
296
297
298def _match(classname: str) -> bool:
299 """Check if the given classname matches a path pattern either using glob format or regexp format."""
300 return _match_glob(classname) or _match_regexp(classname)
301
302
303@functools.lru_cache(maxsize=None)
304def _match_glob(classname: str):
305 """Check if the given classname matches a pattern from allowed_deserialization_classes using glob syntax."""
306 patterns = _get_patterns()
307 return any(fnmatch(classname, p.pattern) for p in patterns)
308
309
310@functools.lru_cache(maxsize=None)
311def _match_regexp(classname: str):
312 """Check if the given classname matches a pattern from allowed_deserialization_classes_regexp using regexp."""
313 patterns = _get_regexp_patterns()
314 return any(p.match(classname) is not None for p in patterns)
315
316
317def _stringify(classname: str, version: int, value: T | None) -> str:
318 """Convert a previously serialized object in a somewhat human-readable format.
319
320 This function is not designed to be exact, and will not extensively traverse
321 the whole tree of an object.
322 """
323 if classname in _stringifiers:
324 return _stringifiers[classname].stringify(classname, version, value)
325
326 s = f"{classname}@version={version}("
327 if isinstance(value, _primitives):
328 s += f"{value}"
329 elif isinstance(value, _builtin_collections):
330 # deserialized values can be != str
331 s += ",".join(str(deserialize(value, full=False)))
332 elif isinstance(value, dict):
333 s += ",".join(f"{k}={deserialize(v, full=False)}" for k, v in value.items())
334 s += ")"
335
336 return s
337
338
339def _is_pydantic(cls: Any) -> bool:
340 """Return True if the class is a pydantic model.
341
342 Checking is done by attributes as it is significantly faster than
343 using isinstance.
344 """
345 return hasattr(cls, "model_config") and hasattr(cls, "model_fields") and hasattr(cls, "model_fields_set")
346
347
348def _is_namedtuple(cls: Any) -> bool:
349 """Return True if the class is a namedtuple.
350
351 Checking is done by attributes as it is significantly faster than
352 using isinstance.
353 """
354 return hasattr(cls, "_asdict") and hasattr(cls, "_fields") and hasattr(cls, "_field_defaults")
355
356
357def _register():
358 """Register builtin serializers and deserializers for types that don't have any themselves."""
359 _serializers.clear()
360 _deserializers.clear()
361 _stringifiers.clear()
362
363 with Stats.timer("serde.load_serializers") as timer:
364 for _, name, _ in iter_namespace(airflow.serialization.serializers):
365 name = import_module(name)
366 for s in getattr(name, "serializers", ()):
367 if not isinstance(s, str):
368 s = qualname(s)
369 if s in _serializers and _serializers[s] != name:
370 raise AttributeError(f"duplicate {s} for serialization in {name} and {_serializers[s]}")
371 log.debug("registering %s for serialization", s)
372 _serializers[s] = name
373 for d in getattr(name, "deserializers", ()):
374 if not isinstance(d, str):
375 d = qualname(d)
376 if d in _deserializers and _deserializers[d] != name:
377 raise AttributeError(f"duplicate {d} for deserialization in {name} and {_serializers[d]}")
378 log.debug("registering %s for deserialization", d)
379 _deserializers[d] = name
380 _extra_allowed.add(d)
381 for c in getattr(name, "stringifiers", ()):
382 if not isinstance(c, str):
383 c = qualname(c)
384 if c in _deserializers and _deserializers[c] != name:
385 raise AttributeError(f"duplicate {c} for stringifiers in {name} and {_stringifiers[c]}")
386 log.debug("registering %s for stringifying", c)
387 _stringifiers[c] = name
388
389 log.debug("loading serializers took %.3f seconds", timer.duration)
390
391
392@functools.lru_cache(maxsize=None)
393def _get_patterns() -> list[Pattern]:
394 return [re2.compile(p) for p in conf.get("core", "allowed_deserialization_classes").split()]
395
396
397@functools.lru_cache(maxsize=None)
398def _get_regexp_patterns() -> list[Pattern]:
399 return [re2.compile(p) for p in conf.get("core", "allowed_deserialization_classes_regexp").split()]
400
401
402_register()