1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements. See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership. The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License. You may obtain a copy of the License at
8#
9# http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied. See the License for the
15# specific language governing permissions and limitations
16# under the License.
17from __future__ import annotations
18
19import dataclasses
20import enum
21import functools
22import logging
23import re
24import sys
25from fnmatch import fnmatch
26from importlib import import_module
27from re import Pattern
28from typing import TYPE_CHECKING, Any, TypeVar, cast, overload
29
30import attr
31
32from airflow.sdk._shared.module_loading import import_string, iter_namespace, qualname
33from airflow.sdk.configuration import conf
34from airflow.sdk.observability.stats import Stats
35from airflow.sdk.serde.typing import is_pydantic_model
36
37if TYPE_CHECKING:
38 from types import ModuleType
39
log = logging.getLogger(__name__)

# serialize() raises RecursionError once its ``depth`` counter reaches this value,
# which is one below the interpreter's recursion limit at import time.
MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1

# Keys of the envelope built by encode() and read back by decode().
CLASSNAME = "__classname__"
VERSION = "__version__"
DATA = "__data__"
# Reserved key: serialize() refuses plain dicts containing CLASSNAME or SCHEMA_ID.
SCHEMA_ID = "__id__"
# Not referenced by the code in this part of the file.
CACHE = "__cache__"

# Keys of the old-style serialization format that _convert() translates.
OLD_TYPE = "__type"
OLD_SOURCE = "__source"
OLD_DATA = "__var"
# Old-style type marker for a plain dict; _convert() unwraps such payloads as-is.
OLD_DICT = "dict"
# Registry key shared by all pydantic models in _serializers / _deserializers.
PYDANTIC_MODEL_QUALNAME = "pydantic.main.BaseModel"

# Version used when a class has no ``__version__`` and for converted old-style data.
DEFAULT_VERSION = 0

T = TypeVar("T", bool, float, int, dict, list, str, tuple, set)
U = bool | float | int | dict | list | str | tuple | set
S = list | tuple | set

# Registries mapping a qualified class name to the module handling it.
# They are (re)populated by _register().
_serializers: dict[str, ModuleType] = {}
_deserializers: dict[str, ModuleType] = {}
_stringifiers: dict[str, ModuleType] = {}
# Class names that have a registered deserializer; deserialize() accepts these
# even when they do not match the configured allow list.
_extra_allowed: set[str] = set()

_primitives = (int, bool, float, str)
_builtin_collections = (frozenset, list, set, tuple)  # dict is treated specially.
69
70
def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]:
    """
    Wrap already serialized data in the envelope understood by the deserializer.

    :param cls: qualified name of the class the data belongs to.
    :param version: serialization version of that class.
    :param data: the serialized payload.
    :return: a dict holding the classname, version and payload under their reserved keys.
    """
    envelope: dict[str, str | int | T] = {}
    envelope[CLASSNAME] = cls
    envelope[VERSION] = version
    envelope[DATA] = data
    return envelope
74
75
def decode(d: dict[str, Any]) -> tuple[str, int, Any]:
    """
    Split an encoded envelope into its classname, version and data.

    :param d: dict that must contain the classname and version keys; data is optional.
    :raise KeyError: the classname or version key is missing.
    :raise ValueError: the classname is not a str or the version is not an int.
    :return: tuple of (classname, version, data) where data is None when absent.
    """
    classname, version = d[CLASSNAME], d[VERSION]

    well_formed = isinstance(classname, str) and isinstance(version, int)
    if not well_formed:
        raise ValueError(f"cannot decode {d!r}")

    return classname, version, d.get(DATA)
86
87
@overload
def serialize(o: dict, depth: int = 0) -> dict: ...
@overload
def serialize(o: None, depth: int = 0) -> None: ...
@overload
def serialize(o: object, depth: int = 0) -> U | None: ...


def serialize(o: object, depth: int = 0) -> U | None:
    """
    Turn ``o`` into a structure that is made up of built-in types only.

    ``None`` stays ``None``. Lists and dicts are walked recursively; dict keys
    are converted with ``str``. Any other value is handled by the first of the
    following that applies:

    1. A serializer module registered for the object's qualified name
       (namedtuples are looked up as ``builtins.tuple``, pydantic models under
       the generic pydantic entry).
    2. Primitives (int, bool, float, str) are returned unchanged, enums among
       them as their ``value``.
    3. A ``serialize`` method provided by the object itself.
    4. The fields of a dataclass or of an attr-annotated class.

    Limitations: ``dataclasses.asdict`` turns nested dataclasses into plain
    dicts, so their type information is not stored and they come back as dicts
    on deserialization. Provide your own serializer to work around this.

    :param o: The object to serialize.
    :param depth: Private tracker for nested serialization.
    :raise AttributeError: A dict to serialize contains a reserved key.
    :raise TypeError: A serializer cannot be found.
    :raise RecursionError: The object is too nested for the function to handle.
    :return: A representation of ``o`` that consists of only built-in types.
    """
    if depth == MAX_RECURSION_DEPTH:
        raise RecursionError("maximum recursion depth reached for serialization")

    if o is None:
        return None

    next_depth = depth + 1

    if isinstance(o, list):
        return [serialize(item, next_depth) for item in o]

    if isinstance(o, dict):
        if any(reserved in o for reserved in (CLASSNAME, SCHEMA_ID)):
            raise AttributeError(f"reserved key {CLASSNAME} or {SCHEMA_ID} found in dict to serialize")

        return {str(key): serialize(val, next_depth) for key, val in o.items()}

    obj_type = type(o)
    lookup_name = qualname(o)
    forced_classname = None

    # Namedtuples are serialized like plain tuples. The encoded classname is forced to
    # "builtins.tuple" as well, so the deserializer restores the value as a tuple.
    if _is_namedtuple(o):
        lookup_name = forced_classname = "builtins.tuple"

    if is_pydantic_model(o):
        # look up the generic pydantic serializer ...
        lookup_name = PYDANTIC_MODEL_QUALNAME
        # ... but encode the concrete model class
        forced_classname = qualname(o)

    # a registered serializer takes precedence over everything below
    if lookup_name in _serializers:
        payload, reported_classname, version, handled = _serializers[lookup_name].serialize(o)
        if handled:
            return encode(forced_classname or reported_classname, version, serialize(payload, next_depth))

    # primitives pass through, enums are reduced to their value
    if isinstance(o, _primitives):
        return o.value if isinstance(o, enum.Enum) else o

    # envelope shared by the custom serialization paths below
    envelope = {
        CLASSNAME: lookup_name,
        VERSION: getattr(obj_type, "__version__", DEFAULT_VERSION),
    }

    # the object brings its own serializer
    if hasattr(o, "serialize"):
        payload = getattr(o, "serialize")()
        # only a dict result is serialized further; anything else is stored as returned
        envelope[DATA] = serialize(payload, next_depth) if isinstance(payload, dict) else payload
        return envelope

    if dataclasses.is_dataclass(obj_type):
        # fixme: asdict loses the type information of nested dataclasses
        fields = dataclasses.asdict(o)  # type: ignore[call-overload]
    elif attr.has(obj_type):
        # keep only the attributes that can be handed back to the class constructor
        fields = attr.asdict(cast("attr.AttrsInstance", o), recurse=False, filter=lambda a, v: a.init)
    else:
        raise TypeError(f"cannot serialize object of type {obj_type}")

    envelope[DATA] = serialize(fields, next_depth)
    return envelope
200
201
def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
    """
    Deserialize an object of primitive type and uses an allow list to determine if a class can be loaded.

    :param o: primitive to deserialize into an arbitrary object.
    :param full: if False it will return a stringified representation
        of an object and will not load any classes
    :param type_hint: if set it will be used to help determine what
        object to deserialize in. It does not override if another
        specification is found
    :raise ImportError: the class is neither in the allow list nor has a registered deserializer.
    :raise TypeError: no classname could be determined, the serialized version is newer than
        the class version, the data for an attr/dataclass is not a dict, or no deserializer exists.
    :return: object
    """
    if o is None:
        return o

    if isinstance(o, _primitives):
        return o

    # tuples, sets are included here for backwards compatibility
    if isinstance(o, _builtin_collections):
        # Pass ``full`` down so that ``full=False`` never loads classes for nested items.
        col = [deserialize(d, full) for d in o]
        if isinstance(o, tuple):
            return tuple(col)

        if isinstance(o, set):
            return set(col)

        return col

    if not isinstance(o, dict):
        # if o is not a dict, then it's already deserialized
        # in this case we should return it as is
        return o

    o = _convert(o)

    # plain dict and no type hint
    # NOTE(review): ``and`` binds tighter than ``or``, so a dict without a version key is always
    # returned as a plain dict, even when a type_hint is given. Confirm this is intended.
    if CLASSNAME not in o and not type_hint or VERSION not in o:
        return {str(k): deserialize(v, full) for k, v in o.items()}

    # custom deserialization starts here
    cls: Any
    version = 0
    value: Any = None
    classname = ""

    if type_hint:
        cls = type_hint
        classname = qualname(cls)
        version = 0  # type hinting always sets version to 0
        value = o

    # an explicit envelope takes precedence over the type hint
    if CLASSNAME in o and VERSION in o:
        classname, version, value = decode(o)

    if not classname:
        raise TypeError("classname cannot be empty")

    # only return string representation
    if not full:
        return _stringify(classname, version, value)
    if not _match(classname) and classname not in _extra_allowed:
        raise ImportError(
            f"{classname} was not found in allow list for deserialization imports. "
            f"To allow it, add it to allowed_deserialization_classes in the configuration"
        )

    cls = import_string(classname)

    # registered deserializer
    if classname in _deserializers:
        return _deserializers[classname].deserialize(cls, version, deserialize(value))
    if is_pydantic_model(cls):
        if PYDANTIC_MODEL_QUALNAME in _deserializers:
            return _deserializers[PYDANTIC_MODEL_QUALNAME].deserialize(cls, version, deserialize(value))

    # class has deserialization function
    if hasattr(cls, "deserialize"):
        return getattr(cls, "deserialize")(deserialize(value), version)

    # attr or dataclass
    if attr.has(cls) or dataclasses.is_dataclass(cls):
        class_version = getattr(cls, "__version__", 0)
        if int(version) > class_version:
            # Exceptions do not interpolate %-style arguments, so format the message here.
            raise TypeError(
                f"serialized version of {classname} is newer than module version "
                f"({version} > {class_version})"
            )

        deserialize_value = deserialize(value)
        if not isinstance(deserialize_value, dict):
            raise TypeError(
                f"deserialized value for {classname} is not a dict, got {type(deserialize_value)}"
            )
        return cls(**deserialize_value)  # type: ignore[operator]

    # no deserializer available
    raise TypeError(f"No deserializer found for {classname}")
302
303
def _convert(old: dict) -> dict:
    """
    Translate an old style serialization into the new style envelope.

    A dict that does not carry both old style keys is returned unchanged. An old
    style plain dict is unwrapped to its payload; anything else is re-wrapped with
    the default version.
    """
    is_old_style = OLD_TYPE in old and OLD_DATA in old
    if not is_old_style:
        return old

    old_type = old[OLD_TYPE]
    payload = old[OLD_DATA]

    # old style dicts need no envelope
    if old_type == OLD_DICT:
        return payload

    return {CLASSNAME: old_type, VERSION: DEFAULT_VERSION, DATA: payload}
313
314
def _match(classname: str) -> bool:
    """Tell whether the classname is allowed by either the glob or the regexp patterns."""
    if _match_glob(classname):
        return True
    return _match_regexp(classname)
318
319
@functools.cache
def _match_glob(classname: str):
    """
    Tell whether the classname matches any allowed_deserialization_classes pattern in glob syntax.

    Results are cached per classname.
    """
    for compiled in _get_patterns():
        # only the pattern text is used; matching is done by fnmatch
        if fnmatch(classname, compiled.pattern):
            return True
    return False
325
326
@functools.cache
def _match_regexp(classname: str):
    """
    Tell whether the classname matches any allowed_deserialization_classes_regexp pattern.

    Patterns are applied with ``match``, i.e. anchored at the start of the classname.
    Results are cached per classname.
    """
    for compiled in _get_regexp_patterns():
        if compiled.match(classname) is not None:
            return True
    return False
332
333
def _stringify(classname: str, version: int, value: T | None) -> str:
    """
    Convert a previously serialized object in a somewhat human-readable format.

    This function is not designed to be exact, and will not extensively traverse
    the whole tree of an object.

    :param classname: qualified name of the serialized class.
    :param version: serialization version of the class.
    :param value: the serialized data of the object.
    :return: the output of a registered stringifier for ``classname`` if there is one,
        otherwise ``classname@version=N(...)`` with the stringified data between the parentheses.
    """
    if classname in _stringifiers:
        return _stringifiers[classname].stringify(classname, version, value)

    s = f"{classname}@version={version}("
    if isinstance(value, _primitives):
        s += f"{value}"
    elif isinstance(value, _builtin_collections):
        # Deserialized values can be != str, so convert every item before joining.
        # Joining ``str()`` of the whole collection would insert a comma between
        # each character of its repr.
        s += ",".join(str(deserialize(item, full=False)) for item in value)
    elif isinstance(value, dict):
        s += ",".join(f"{k}={deserialize(v, full=False)}" for k, v in value.items())
    s += ")"

    return s
355
356
357def _is_namedtuple(cls: Any) -> bool:
358 """
359 Return True if the class is a namedtuple.
360
361 Checking is done by attributes as it is significantly faster than
362 using isinstance.
363 """
364 return hasattr(cls, "_asdict") and hasattr(cls, "_fields") and hasattr(cls, "_field_defaults")
365
366
def _register():
    """
    Register builtin serializers and deserializers for types that don't have any themselves.

    Every module in the ``airflow.sdk.serde.serializers`` namespace is imported and the
    entries of its ``serializers``, ``deserializers`` and ``stringifiers`` attributes
    (classes or qualified names) are mapped to that module. Classes with a registered
    deserializer are also added to the extra allow list.

    :raise AttributeError: two different modules register the same name for the same purpose.
    """
    _serializers.clear()
    _deserializers.clear()
    _stringifiers.clear()
    # NOTE: _extra_allowed is only ever added to; it is not reset here.

    with Stats.timer("serde.load_serializers") as timer:
        serializers_module = import_module("airflow.sdk.serde.serializers")
        for _, module_name, _ in iter_namespace(serializers_module):
            module = import_module(module_name)
            for serializers in getattr(module, "serializers", ()):
                s_qualname = serializers if isinstance(serializers, str) else qualname(serializers)
                if s_qualname in _serializers and _serializers[s_qualname] != module:
                    raise AttributeError(
                        f"duplicate {s_qualname} for serialization in {module} and {_serializers[s_qualname]}"
                    )
                log.debug("registering %s for serialization", s_qualname)
                _serializers[s_qualname] = module
            for deserializers in getattr(module, "deserializers", ()):
                d_qualname = deserializers if isinstance(deserializers, str) else qualname(deserializers)
                if d_qualname in _deserializers and _deserializers[d_qualname] != module:
                    raise AttributeError(
                        f"duplicate {d_qualname} for deserialization in {module} and {_deserializers[d_qualname]}"
                    )
                log.debug("registering %s for deserialization", d_qualname)
                _deserializers[d_qualname] = module
                _extra_allowed.add(d_qualname)
            for stringifiers in getattr(module, "stringifiers", ()):
                c_qualname = stringifiers if isinstance(stringifiers, str) else qualname(stringifiers)
                # Duplicates must be looked up in the stringifier registry itself; checking the
                # deserializer registry could raise KeyError while building the message or
                # reject a name that a different module legitimately deserializes.
                if c_qualname in _stringifiers and _stringifiers[c_qualname] != module:
                    raise AttributeError(
                        f"duplicate {c_qualname} for stringifiers in {module} and {_stringifiers[c_qualname]}"
                    )
                log.debug("registering %s for stringifying", c_qualname)
                _stringifiers[c_qualname] = module

    log.debug("loading serializers took %.3f seconds", timer.duration)
404
405
@functools.cache
def _get_patterns() -> list[Pattern]:
    """
    Return the whitespace separated ``core.allowed_deserialization_classes`` entries as compiled patterns.

    The result is cached, so the configuration is read once.

    NOTE(review): the entries are glob patterns (``_match_glob`` only reads ``.pattern``),
    yet they are passed through ``re.compile``; a glob that is not a valid regular
    expression would presumably raise here. Confirm against supported configurations.
    """
    configured = conf.get("core", "allowed_deserialization_classes")
    return list(map(re.compile, configured.split()))
409
410
@functools.cache
def _get_regexp_patterns() -> list[Pattern]:
    """
    Return the whitespace separated ``core.allowed_deserialization_classes_regexp`` entries, compiled.

    The result is cached, so the configuration is read once.
    """
    configured = conf.get("core", "allowed_deserialization_classes_regexp")
    return list(map(re.compile, configured.split()))
414
415
# Populate the serializer, deserializer and stringifier registries when this module is imported.
_register()