#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import dataclasses
import enum
import functools
import logging
import re
import sys
from fnmatch import fnmatch
from importlib import import_module
from re import Pattern
from typing import TYPE_CHECKING, Any, TypeVar, cast

import attr

import airflow.serialization.serializers
from airflow._shared.module_loading import import_string, iter_namespace, qualname
from airflow.configuration import conf
from airflow.observability.stats import Stats
from airflow.serialization.typing import is_pydantic_model

if TYPE_CHECKING:
    from types import ModuleType

log = logging.getLogger(__name__)

MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1

CLASSNAME = "__classname__"
VERSION = "__version__"
DATA = "__data__"
SCHEMA_ID = "__id__"
CACHE = "__cache__"

OLD_TYPE = "__type"
OLD_SOURCE = "__source"
OLD_DATA = "__var"
OLD_DICT = "dict"
PYDANTIC_MODEL_QUALNAME = "pydantic.main.BaseModel"

DEFAULT_VERSION = 0

T = TypeVar("T", bool, float, int, dict, list, str, tuple, set)
U = bool | float | int | dict | list | str | tuple | set
S = list | tuple | set

_serializers: dict[str, ModuleType] = {}
_deserializers: dict[str, ModuleType] = {}
_stringifiers: dict[str, ModuleType] = {}
_extra_allowed: set[str] = set()

_primitives = (int, bool, float, str)
_builtin_collections = (frozenset, list, set, tuple)  # dict is treated specially.


def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]:
    """Encode an object so it can be understood by the deserializer."""
    return {CLASSNAME: cls, VERSION: version, DATA: data}


def decode(d: dict[str, Any]) -> tuple[str, int, Any]:
    """Decode a previously encoded dict into a (classname, version, data) tuple."""
    classname = d[CLASSNAME]
    version = d[VERSION]

    if not isinstance(classname, str) or not isinstance(version, int):
        raise ValueError(f"cannot decode {d!r}")

    data = d.get(DATA)

    return classname, version, data

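# Illustrative sketch (not part of the module API) of the envelope produced by ``encode``
# and consumed by ``decode``. For example, encode("myapp.Color", 1, "RED") would yield
#   {"__classname__": "myapp.Color", "__version__": 1, "__data__": "RED"}
# and decode() on that dict would return the triple ("myapp.Color", 1, "RED").
# The name "myapp.Color" here is a hypothetical example, not a registered serializer.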

def serialize(o: object, depth: int = 0) -> U | None:
    """
    Serialize an object into a representation consisting only of built-in types.

    Primitives (int, float, bool, str) are returned as-is. Built-in collections
    are iterated over, where it is assumed that keys in a dict can be represented
    as str.

    Values that are not of a built-in type are serialized if a serializer is
    found for them. The order in which serializers are tried is:

    1. A ``serialize`` function provided by the object.
    2. A registered serializer in the namespace of ``airflow.serialization.serializers``.
    3. Annotations from attr or dataclass.

    Limitations: attr and dataclass objects can lose type information for nested objects,
    as ``asdict`` does not store it. This means that at deserialization time such nested
    values come back as plain dicts rather than being reinstated as objects. Provide
    your own serializer to work around this.

    :param o: The object to serialize.
    :param depth: Private tracker for nested serialization.
    :raise TypeError: A serializer cannot be found.
    :raise RecursionError: The object is too nested for the function to handle.
    :return: A representation of ``o`` that consists of only built-in types.
    """
    if depth == MAX_RECURSION_DEPTH:
        raise RecursionError("maximum recursion depth reached for serialization")

    # None remains None
    if o is None:
        return o

    if isinstance(o, list):
        return [serialize(d, depth + 1) for d in o]

    if isinstance(o, dict):
        if CLASSNAME in o or SCHEMA_ID in o:
            raise AttributeError(f"reserved key {CLASSNAME} or {SCHEMA_ID} found in dict to serialize")

        return {str(k): serialize(v, depth + 1) for k, v in o.items()}

    cls = type(o)
    qn = qualname(o)
    classname = None

    # Serialize namedtuples like plain tuples.
    # We also override the classname returned by the builtin.py serializer. The classname
    # has to be "builtins.tuple", so that the deserializer can deserialize the object into a tuple.
    if _is_namedtuple(o):
        qn = "builtins.tuple"
        classname = qn

    if is_pydantic_model(o):
        # to match the generic Pydantic serializer and deserializer in _serializers and _deserializers
        qn = PYDANTIC_MODEL_QUALNAME
        # the actual Pydantic model class to encode
        classname = qualname(o)

    # if there is a builtin serializer available, use that
    if qn in _serializers:
        data, serialized_classname, version, is_serialized = _serializers[qn].serialize(o)
        if is_serialized:
            return encode(classname or serialized_classname, version, serialize(data, depth + 1))

    # primitive types are returned as-is
    if isinstance(o, _primitives):
        if isinstance(o, enum.Enum):
            return o.value

        return o

    # custom serializers
    dct = {
        CLASSNAME: qn,
        VERSION: getattr(cls, "__version__", DEFAULT_VERSION),
    }

    # the object / class brings its own serialize method
    if hasattr(o, "serialize"):
        data = getattr(o, "serialize")()

        # if we end up with a structure, ensure its values are serialized
        if isinstance(data, dict):
            data = serialize(data, depth + 1)

        dct[DATA] = data
        return dct

    # dataclasses
    if dataclasses.is_dataclass(cls):
        # fixme: unfortunately asdict loses type information for nested dataclasses
        data = dataclasses.asdict(o)  # type: ignore[call-overload]
        dct[DATA] = serialize(data, depth + 1)
        return dct

    # attr annotated
    if attr.has(cls):
        # Only include attributes which we can pass back to the class's constructor
        data = attr.asdict(cast("attr.AttrsInstance", o), recurse=False, filter=lambda a, v: a.init)
        dct[DATA] = serialize(data, depth + 1)
        return dct

    raise TypeError(f"cannot serialize object of type {cls}")

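# Illustrative sketch of what ``serialize`` is expected to produce for an annotated class;
# the ``Point`` dataclass below is hypothetical and not part of this module:
#
#   @dataclasses.dataclass
#   class Point:
#       __version__ = 1  # unannotated, so a class attribute rather than a dataclass field
#       x: int
#       y: int
#
#   serialize(Point(1, 2)) would yield something like
#   {"__classname__": "<module>.Point", "__version__": 1, "__data__": {"x": 1, "y": 2}}
# because the dataclass branch above calls ``dataclasses.asdict`` and wraps the result
# in the CLASSNAME / VERSION / DATA envelope.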

def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
    """
    Deserialize an object of primitive type, using an allow list to determine whether a class can be loaded.

    :param o: primitive to deserialize into an arbitrary object.
    :param full: if False, return a stringified representation of the object
        and do not load any classes.
    :param type_hint: if set, used to help determine what object to deserialize
        into. It does not override a specification found in the serialized data.
    :return: the deserialized object.
    """
    if o is None:
        return o

    if isinstance(o, _primitives):
        return o

    # tuples, sets are included here for backwards compatibility
    if isinstance(o, _builtin_collections):
        col = [deserialize(d) for d in o]
        if isinstance(o, tuple):
            return tuple(col)

        if isinstance(o, set):
            return set(col)

        return col

    if not isinstance(o, dict):
        # if o is not a dict, then it's already deserialized
        # in this case we should return it as is
        return o

    o = _convert(o)

    # plain dict and no type hint
    if CLASSNAME not in o and not type_hint or VERSION not in o:
        return {str(k): deserialize(v, full) for k, v in o.items()}

    # custom deserialization starts here
    cls: Any
    version = 0
    value: Any = None
    classname = ""

    if type_hint:
        cls = type_hint
        classname = qualname(cls)
        version = 0  # type hinting always sets version to 0
        value = o

    if CLASSNAME in o and VERSION in o:
        classname, version, value = decode(o)

    if not classname:
        raise TypeError("classname cannot be empty")

    # only return string representation
    if not full:
        return _stringify(classname, version, value)
    if not _match(classname) and classname not in _extra_allowed:
        raise ImportError(
            f"{classname} was not found in allow list for deserialization imports. "
            f"To allow it, add it to allowed_deserialization_classes in the configuration"
        )

    cls = import_string(classname)

    # registered deserializer
    if classname in _deserializers:
        return _deserializers[classname].deserialize(cls, version, deserialize(value))
    if is_pydantic_model(cls):
        if PYDANTIC_MODEL_QUALNAME in _deserializers:
            return _deserializers[PYDANTIC_MODEL_QUALNAME].deserialize(cls, version, deserialize(value))

    # class has deserialization function
    if hasattr(cls, "deserialize"):
        return getattr(cls, "deserialize")(deserialize(value), version)

    # attr or dataclass
    if attr.has(cls) or dataclasses.is_dataclass(cls):
        class_version = getattr(cls, "__version__", 0)
        if int(version) > class_version:
            raise TypeError(
                f"serialized version of {classname} is newer than module version ({version} > {class_version})"
            )

        deserialize_value = deserialize(value)
        if not isinstance(deserialize_value, dict):
            raise TypeError(
                f"deserialized value for {classname} is not a dict, got {type(deserialize_value)}"
            )
        return cls(**deserialize_value)  # type: ignore[operator]

    # no deserializer available
    raise TypeError(f"No deserializer found for {classname}")

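# Illustrative notes on the allow-list gate above (names below are hypothetical):
# with ``[core] allowed_deserialization_classes = airflow.* myapp.models.Point``,
# deserializing {"__classname__": "myapp.models.Point", "__version__": 1, "__data__": {"x": 1, "y": 2}}
# would import myapp.models.Point and construct it from the data, whereas a classname matching
# neither the glob list, the regexp list, nor ``_extra_allowed`` raises ImportError.
# With full=False no import happens at all; only the _stringify() representation is returned.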

def _convert(old: dict) -> dict:
    """Convert an old-style serialization to the new style."""
    if OLD_TYPE in old and OLD_DATA in old:
        # Return old-style dicts directly as they do not need wrapping
        if old[OLD_TYPE] == OLD_DICT:
            return old[OLD_DATA]
        return {CLASSNAME: old[OLD_TYPE], VERSION: DEFAULT_VERSION, DATA: old[OLD_DATA]}

    return old

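# Illustrative sketch of the legacy format handled by _convert() (values are made up):
#   {"__type": "dict", "__var": {"a": 1}}       -> {"a": 1} (unwrapped directly)
#   {"__type": "myapp.Foo", "__var": {"x": 1}}  -> {"__classname__": "myapp.Foo",
#                                                   "__version__": 0, "__data__": {"x": 1}}
# where "myapp.Foo" is a hypothetical class name used only for this example.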

def _match(classname: str) -> bool:
    """Check if the given classname matches a path pattern either using glob format or regexp format."""
    return _match_glob(classname) or _match_regexp(classname)


@functools.cache
def _match_glob(classname: str):
    """Check if the given classname matches a pattern from allowed_deserialization_classes using glob syntax."""
    patterns = _get_patterns()
    return any(fnmatch(classname, p.pattern) for p in patterns)


@functools.cache
def _match_regexp(classname: str):
    """Check if the given classname matches a pattern from allowed_deserialization_classes_regexp using regexp."""
    patterns = _get_regexp_patterns()
    return any(p.match(classname) is not None for p in patterns)

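# Illustrative configuration for the two allow-list options checked above (example values only):
#   [core]
#   allowed_deserialization_classes = airflow.* myapp.models.*
#   allowed_deserialization_classes_regexp = ^myapp\.legacy\.[A-Z]\w+$
# Both options are space-separated lists; the first is matched with fnmatch() glob semantics,
# the second with re.match(). A classname is allowed if it matches either list.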

def _stringify(classname: str, version: int, value: T | None) -> str:
    """
    Convert a previously serialized object into a somewhat human-readable format.

    This function is not designed to be exact, and will not extensively traverse
    the whole tree of an object.
    """
    if classname in _stringifiers:
        return _stringifiers[classname].stringify(classname, version, value)

    s = f"{classname}@version={version}("
    if isinstance(value, _primitives):
        s += f"{value}"
    elif isinstance(value, _builtin_collections):
        # deserialized values can be != str, so stringify each element before joining
        s += ",".join(str(deserialize(v, full=False)) for v in value)
    elif isinstance(value, dict):
        s += ",".join(f"{k}={deserialize(v, full=False)}" for k, v in value.items())
    s += ")"

    return s

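# Illustrative sketch of the fallback string produced above for a class without a registered
# stringifier (class name and fields are hypothetical):
#   _stringify("myapp.Point", 1, {"x": 1, "y": 2}) -> "myapp.Point@version=1(x=1,y=2)"
# Stringifiers registered via airflow.serialization.serializers take precedence and may
# format their types differently.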

def _is_namedtuple(cls: Any) -> bool:
    """
    Return True if the class is a namedtuple.

    Checking is done by attributes as it is significantly faster than
    using isinstance.
    """
    return hasattr(cls, "_asdict") and hasattr(cls, "_fields") and hasattr(cls, "_field_defaults")


def _register():
    """Register builtin serializers and deserializers for types that don't have any themselves."""
    _serializers.clear()
    _deserializers.clear()
    _stringifiers.clear()

    with Stats.timer("serde.load_serializers") as timer:
        for _, module_name, _ in iter_namespace(airflow.serialization.serializers):
            module = import_module(module_name)
            for serializers in getattr(module, "serializers", ()):
                s_qualname = serializers if isinstance(serializers, str) else qualname(serializers)
                if s_qualname in _serializers and _serializers[s_qualname] != module:
                    raise AttributeError(
                        f"duplicate {s_qualname} for serialization in {module} and {_serializers[s_qualname]}"
                    )
                log.debug("registering %s for serialization", s_qualname)
                _serializers[s_qualname] = module
            for deserializers in getattr(module, "deserializers", ()):
                d_qualname = deserializers if isinstance(deserializers, str) else qualname(deserializers)
                if d_qualname in _deserializers and _deserializers[d_qualname] != module:
                    raise AttributeError(
                        f"duplicate {d_qualname} for deserialization in {module} and {_deserializers[d_qualname]}"
                    )
                log.debug("registering %s for deserialization", d_qualname)
                _deserializers[d_qualname] = module
                _extra_allowed.add(d_qualname)
            for stringifiers in getattr(module, "stringifiers", ()):
                c_qualname = stringifiers if isinstance(stringifiers, str) else qualname(stringifiers)
                if c_qualname in _stringifiers and _stringifiers[c_qualname] != module:
                    raise AttributeError(
                        f"duplicate {c_qualname} for stringifiers in {module} and {_stringifiers[c_qualname]}"
                    )
                log.debug("registering %s for stringifying", c_qualname)
                _stringifiers[c_qualname] = module

    log.debug("loading serializers took %.3f seconds", timer.duration)

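# Illustrative sketch of the module-level contract that _register() expects from modules in the
# ``airflow.serialization.serializers`` namespace, inferred from the calls above (the module name
# and handled type below are hypothetical):
#
#   # airflow/serialization/serializers/mytype.py
#   serializers = ["myapp.MyType"]        # classes or qualname strings handled by serialize()
#   deserializers = serializers           # qualnames handled by deserialize(); also added to _extra_allowed
#   stringifiers = serializers            # qualnames handled by stringify()
#
#   def serialize(o):                     # -> (data, classname, version, is_serialized)
#       ...
#
#   def deserialize(cls, version, data):  # -> reconstructed object
#       ...
#
#   def stringify(classname, version, value):  # -> human-readable str
#       ...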

@functools.cache
def _get_patterns() -> list[Pattern]:
    return [re.compile(p) for p in conf.get("core", "allowed_deserialization_classes").split()]


@functools.cache
def _get_regexp_patterns() -> list[Pattern]:
    return [re.compile(p) for p in conf.get("core", "allowed_deserialization_classes_regexp").split()]


_register()