Coverage for /pythoncovmergedfiles/medio/medio/usr/local/lib/python3.8/site-packages/airflow/serialization/serde.py: 39%

Shortcuts on this page

r m x   toggle line displays

j k   next/prev highlighted chunk

0   (zero) top of page

1   (one) first highlighted chunk

205 statements  

1# 

2# Licensed to the Apache Software Foundation (ASF) under one 

3# or more contributor license agreements. See the NOTICE file 

4# distributed with this work for additional information 

5# regarding copyright ownership. The ASF licenses this file 

6# to you under the Apache License, Version 2.0 (the 

7# "License"); you may not use this file except in compliance 

8# with the License. You may obtain a copy of the License at 

9# 

10# http://www.apache.org/licenses/LICENSE-2.0 

11# 

12# Unless required by applicable law or agreed to in writing, 

13# software distributed under the License is distributed on an 

14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 

15# KIND, either express or implied. See the License for the 

16# specific language governing permissions and limitations 

17# under the License. 

18from __future__ import annotations 

19 

20import dataclasses 

21import enum 

22import functools 

23import logging 

24import sys 

25from fnmatch import fnmatch 

26from importlib import import_module 

27from typing import TYPE_CHECKING, Any, Pattern, TypeVar, Union, cast 

28 

29import attr 

30import re2 

31 

32import airflow.serialization.serializers 

33from airflow.configuration import conf 

34from airflow.stats import Stats 

35from airflow.utils.module_loading import import_string, iter_namespace, qualname 

36 

37if TYPE_CHECKING: 

38 from types import ModuleType 

39 

log = logging.getLogger(__name__)

# serialize() raises RecursionError once its depth counter reaches this value.
MAX_RECURSION_DEPTH = sys.getrecursionlimit() - 1

# Keys of the envelope built by encode() and read back by decode().
CLASSNAME = "__classname__"
VERSION = "__version__"
DATA = "__data__"
# serialize() rejects plain dicts that contain CLASSNAME or SCHEMA_ID (reserved keys).
SCHEMA_ID = "__id__"
CACHE = "__cache__"  # NOTE(review): not referenced in this part of the file.

# Keys/tags of the legacy ("old style") format that _convert() translates.
OLD_TYPE = "__type"
OLD_SOURCE = "__source"  # NOTE(review): not referenced in this part of the file.
OLD_DATA = "__var"
# Legacy type tag whose payload _convert() returns unwrapped.
OLD_DICT = "dict"

# Version used when a class defines no ``__version__`` and for converted legacy data.
DEFAULT_VERSION = 0

# Type aliases used in the annotations below: T for values handled by encode/deserialize,
# U for the return type of serialize, S for sequence-like built-ins.
T = TypeVar("T", bool, float, int, dict, list, str, tuple, set)
U = Union[bool, float, int, dict, list, str, tuple, set]
S = Union[list, tuple, set]

# Registries filled by _register(): qualified class name -> the module in
# ``airflow.serialization.serializers`` that handles it.
_serializers: dict[str, ModuleType] = {}
_deserializers: dict[str, ModuleType] = {}
_stringifiers: dict[str, ModuleType] = {}
# Classnames with a registered deserializer; deserialize() accepts these even
# when they do not match the configured allow list.
_extra_allowed: set[str] = set()

# bool is a subclass of int, so listing it is redundant but harmless.
_primitives = (int, bool, float, str)
_builtin_collections = (frozenset, list, set, tuple)  # dict is treated specially.

68 

69 

def encode(cls: str, version: int, data: T) -> dict[str, str | int | T]:
    """Wrap an already-serialized payload in the envelope read by ``decode``.

    :param cls: qualified class name to record for the payload.
    :param version: serialization version to record for the payload.
    :param data: the payload itself.
    :return: a dict keyed by ``CLASSNAME``, ``VERSION`` and ``DATA`` (in that order).
    """
    envelope_keys = (CLASSNAME, VERSION, DATA)
    envelope_values = (cls, version, data)
    return dict(zip(envelope_keys, envelope_values))

73 

74 

def decode(d: dict[str, Any]) -> tuple[str, int, Any]:
    """Unpack an envelope produced by ``encode``.

    :param d: dict that must contain ``CLASSNAME`` and ``VERSION``; ``DATA`` is optional.
    :raise KeyError: ``CLASSNAME`` or ``VERSION`` is missing.
    :raise ValueError: the classname is not a str or the version is not an int.
    :return: ``(classname, version, data)``; ``data`` is None when ``DATA`` is absent.
    """
    name, ver = d[CLASSNAME], d[VERSION]

    if isinstance(name, str) and isinstance(ver, int):
        return name, ver, d.get(DATA)

    raise ValueError(f"cannot decode {d!r}")

85 

86 

def serialize(o: object, depth: int = 0) -> U | None:
    """Serialize an object into a representation consisting only of built-in types.

    ``None`` and primitives (int, float, bool, str) are returned as-is, except that
    enum members which are also primitives (e.g. ``IntEnum``) are reduced to their
    ``value``. Lists are serialized item by item; dicts are serialized value by
    value, with their keys converted via ``str()``.

    Any other value is serialized if a serializer is found for it. Serializers are
    tried in this order:

    1. A registered serializer in the namespace of ``airflow.serialization.serializers``
       (namedtuples are routed to the ``builtins.tuple`` serializer). If that
       serializer reports it could not handle the value, the next options are tried.
    2. A ``serialize`` method provided by the object.
    3. ``model_dump`` for pydantic models.
    4. ``asdict`` for dataclasses and attr-annotated classes.

    Limitations: attr and dataclass objects can lose type information for nested objects
    as they do not store this when calling ``asdict``. This means that at deserialization values
    will be deserialized as a dict as opposed to reinstating the object. Provide
    your own serializer to work around this.

    :param o: The object to serialize.
    :param depth: Private tracker for nested serialization.
    :raise TypeError: A serializer cannot be found.
    :raise AttributeError: A dict to serialize contains the reserved key
        ``CLASSNAME`` or ``SCHEMA_ID``.
    :raise RecursionError: The object is too nested for the function to handle.
    :return: A representation of ``o`` that consists of only built-in types.
    """
    if depth == MAX_RECURSION_DEPTH:
        raise RecursionError("maximum recursion depth reached for serialization")

    # None remains None
    if o is None:
        return o

    # primitive types are returned as is
    if isinstance(o, _primitives):
        # e.g. IntEnum / str-based enums are primitives too; keep only their value
        if isinstance(o, enum.Enum):
            return o.value

        return o

    if isinstance(o, list):
        return [serialize(d, depth + 1) for d in o]

    if isinstance(o, dict):
        if CLASSNAME in o or SCHEMA_ID in o:
            raise AttributeError(f"reserved key {CLASSNAME} or {SCHEMA_ID} found in dict to serialize")

        return {str(k): serialize(v, depth + 1) for k, v in o.items()}

    cls = type(o)
    qn = qualname(o)
    classname = None

    # Serialize namedtuple like tuples
    # We also override the classname returned by the builtin.py serializer. The classname
    # has to be "builtins.tuple", so that the deserializer can deserialize the object into tuple.
    if _is_namedtuple(o):
        qn = "builtins.tuple"
        classname = qn

    # if there is a builtin serializer available use that
    if qn in _serializers:
        data, serialized_classname, version, is_serialized = _serializers[qn].serialize(o)
        if is_serialized:
            return encode(classname or serialized_classname, version, serialize(data, depth + 1))
        # otherwise fall through to the remaining strategies

    # custom serializers
    dct = {
        CLASSNAME: qn,
        VERSION: getattr(cls, "__version__", DEFAULT_VERSION),
    }

    # object / class brings their own
    if hasattr(o, "serialize"):
        data = getattr(o, "serialize")()

        # if we end up with a structure, ensure its values are serialized.
        # NOTE: only dict results are recursed into; other return values are stored as-is.
        if isinstance(data, dict):
            data = serialize(data, depth + 1)

        dct[DATA] = data
        return dct

    # pydantic models are recursive
    if _is_pydantic(cls):
        data = o.model_dump()  # type: ignore[attr-defined]
        dct[DATA] = serialize(data, depth + 1)
        return dct

    # dataclasses
    if dataclasses.is_dataclass(cls):
        # fixme: unfortunately using asdict with nested dataclasses it loses information
        data = dataclasses.asdict(o)  # type: ignore[call-overload]
        dct[DATA] = serialize(data, depth + 1)
        return dct

    # attr annotated
    if attr.has(cls):
        # Only include attributes which we can pass back to the classes constructor
        data = attr.asdict(cast(attr.AttrsInstance, o), recurse=False, filter=lambda a, v: a.init)
        dct[DATA] = serialize(data, depth + 1)
        return dct

    raise TypeError(f"cannot serialize object of type {cls}")

190 

191 

def deserialize(o: T | None, full=True, type_hint: Any = None) -> object:
    """
    Deserialize an object of primitive type and use an allow list to determine if a class can be loaded.

    :param o: primitive to deserialize into an arbitrary object.
    :param full: if False it will return a stringified representation
        of an object and will not load any classes. This also holds for
        values nested inside collections and dicts.
    :param type_hint: if set it will be used to help determine what
        object to deserialize in. It does not override if another
        specification is found
    :raise TypeError: the classname is empty, the serialized version is newer
        than the class supports, or no deserializer is found.
    :raise ImportError: the classname is not in the deserialization allow list.
    :return: object
    """
    if o is None:
        return o

    if isinstance(o, _primitives):
        return o

    # tuples, sets and frozensets are included here for backwards compatibility
    if isinstance(o, _builtin_collections):
        # pass ``full`` down so that full=False never imports nested classes
        col = [deserialize(d, full) for d in o]
        if isinstance(o, tuple):
            return tuple(col)

        if isinstance(o, set):
            return set(col)

        # frozenset is not a subclass of set; keep its type like tuple/set above
        if isinstance(o, frozenset):
            return frozenset(col)

        return col

    if not isinstance(o, dict):
        # if o is not a dict, then it's already deserialized
        # in this case we should return it as is
        return o

    o = _convert(o)

    # plain dict and no type hint.
    # NOTE(review): a dict without VERSION is always treated as a plain dict, even when
    # type_hint is given; the parentheses only make the existing precedence explicit.
    if (CLASSNAME not in o and not type_hint) or VERSION not in o:
        return {str(k): deserialize(v, full) for k, v in o.items()}

    # custom deserialization starts here
    cls: Any
    version = 0
    value: Any = None
    classname = ""

    if type_hint:
        cls = type_hint
        classname = qualname(cls)
        version = 0  # type hinting always sets version to 0
        value = o

    # an explicit envelope takes precedence over the type hint
    if CLASSNAME in o and VERSION in o:
        classname, version, value = decode(o)

    if not classname:
        raise TypeError("classname cannot be empty")

    # only return string representation
    if not full:
        return _stringify(classname, version, value)
    if not _match(classname) and classname not in _extra_allowed:
        raise ImportError(
            f"{classname} was not found in allow list for deserialization imports. "
            f"To allow it, add it to allowed_deserialization_classes in the configuration"
        )

    cls = import_string(classname)

    # registered deserializer
    if classname in _deserializers:
        return _deserializers[classname].deserialize(classname, version, deserialize(value))

    # class has deserialization function
    if hasattr(cls, "deserialize"):
        return getattr(cls, "deserialize")(deserialize(value), version)

    # attr or dataclass or pydantic
    if attr.has(cls) or dataclasses.is_dataclass(cls) or _is_pydantic(cls):
        class_version = getattr(cls, "__version__", 0)
        if int(version) > class_version:
            # format the message eagerly: exceptions do not interpolate %-style args
            raise TypeError(
                f"serialized version of {classname} is newer than module version "
                f"({version} > {class_version})"
            )

        return cls(**deserialize(value))

    # no deserializer available
    raise TypeError(f"No deserializer found for {classname}")

def _convert(old: dict) -> dict:
    """Translate a legacy ``__type``/``__var`` payload into the current envelope.

    Dicts that do not carry both legacy keys are returned unchanged. A legacy
    payload tagged ``OLD_DICT`` is unwrapped to its data; any other legacy payload
    is re-wrapped under ``CLASSNAME``/``VERSION``/``DATA`` with ``DEFAULT_VERSION``.
    """
    if OLD_TYPE not in old or OLD_DATA not in old:
        return old

    legacy_type = old[OLD_TYPE]
    legacy_data = old[OLD_DATA]

    # legacy dicts need no wrapping
    if legacy_type == OLD_DICT:
        return legacy_data

    return {CLASSNAME: legacy_type, VERSION: DEFAULT_VERSION, DATA: legacy_data}

296 

297 

def _match(classname: str) -> bool:
    """Return whether ``classname`` is allowed by either the glob or the regexp allow list.

    The glob list is consulted first; the regexp list only if the glob list does not match.
    """
    if _match_glob(classname):
        return True
    return _match_regexp(classname)

301 

302 

@functools.lru_cache(maxsize=None)
def _match_glob(classname: str):
    """Return whether ``classname`` matches an ``allowed_deserialization_classes`` glob.

    Each configured entry's source text is matched with ``fnmatch``; results are
    memoized per classname.
    """
    for compiled in _get_patterns():
        if fnmatch(compiled.pattern, classname) if False else fnmatch(classname, compiled.pattern):
            return True
    return False

308 

309 

@functools.lru_cache(maxsize=None)
def _match_regexp(classname: str):
    """Return whether ``classname`` matches an ``allowed_deserialization_classes_regexp`` entry.

    Each compiled pattern is tried with ``match``; results are memoized per classname.
    """
    for regexp in _get_regexp_patterns():
        if regexp.match(classname) is not None:
            return True
    return False

315 

316 

def _stringify(classname: str, version: int, value: T | None) -> str:
    """Convert a previously serialized object in a somewhat human-readable format.

    A registered stringifier for ``classname`` is used when available. Otherwise the
    result has the form ``classname@version=N(...)``, where the parentheses hold the
    primitive value, the comma-joined items of a collection, or ``key=value`` pairs
    of a dict. Nested values are rendered with ``deserialize(..., full=False)``, so
    no classes are imported.

    This function is not designed to be exact, and will not extensively traverse
    the whole tree of an object.
    """
    if classname in _stringifiers:
        return _stringifiers[classname].stringify(classname, version, value)

    s = f"{classname}@version={version}("
    if isinstance(value, _primitives):
        s += f"{value}"
    elif isinstance(value, _builtin_collections):
        # deserialized values can be != str, so stringify each item before joining
        # (joining str(collection) would put commas between its characters)
        s += ",".join(str(deserialize(v, full=False)) for v in value)
    elif isinstance(value, dict):
        s += ",".join(f"{k}={deserialize(v, full=False)}" for k, v in value.items())
    s += ")"

    return s

337 

338 

339def _is_pydantic(cls: Any) -> bool: 

340 """Return True if the class is a pydantic model. 

341 

342 Checking is done by attributes as it is significantly faster than 

343 using isinstance. 

344 """ 

345 return hasattr(cls, "model_config") and hasattr(cls, "model_fields") and hasattr(cls, "model_fields_set") 

346 

347 

348def _is_namedtuple(cls: Any) -> bool: 

349 """Return True if the class is a namedtuple. 

350 

351 Checking is done by attributes as it is significantly faster than 

352 using isinstance. 

353 """ 

354 return hasattr(cls, "_asdict") and hasattr(cls, "_fields") and hasattr(cls, "_field_defaults") 

355 

356 

def _register():
    """Register builtin serializers and deserializers for types that don't have any themselves.

    Every module in the ``airflow.serialization.serializers`` namespace may expose
    ``serializers``, ``deserializers`` and ``stringifiers`` iterables containing
    classes or qualified class names. Each entry is recorded in the matching
    registry, keyed by qualified name, with the module as value. Classnames with a
    deserializer are also added to ``_extra_allowed``.

    :raise AttributeError: two different modules register the same classname for
        the same role.
    """
    _serializers.clear()
    _deserializers.clear()
    _stringifiers.clear()

    with Stats.timer("serde.load_serializers") as timer:
        for _, name, _ in iter_namespace(airflow.serialization.serializers):
            module = import_module(name)
            for s in getattr(module, "serializers", ()):
                if not isinstance(s, str):
                    s = qualname(s)
                if s in _serializers and _serializers[s] != module:
                    raise AttributeError(f"duplicate {s} for serialization in {module} and {_serializers[s]}")
                log.debug("registering %s for serialization", s)
                _serializers[s] = module
            for d in getattr(module, "deserializers", ()):
                if not isinstance(d, str):
                    d = qualname(d)
                if d in _deserializers and _deserializers[d] != module:
                    # report the conflicting *deserializer* (looking it up in
                    # _serializers could itself raise KeyError)
                    raise AttributeError(f"duplicate {d} for deserialization in {module} and {_deserializers[d]}")
                log.debug("registering %s for deserialization", d)
                _deserializers[d] = module
                # registered deserializers bypass the configured allow list
                _extra_allowed.add(d)
            for c in getattr(module, "stringifiers", ()):
                if not isinstance(c, str):
                    c = qualname(c)
                # duplicates must be checked against the stringifier registry itself
                if c in _stringifiers and _stringifiers[c] != module:
                    raise AttributeError(f"duplicate {c} for stringifiers in {module} and {_stringifiers[c]}")
                log.debug("registering %s for stringifying", c)
                _stringifiers[c] = module

    log.debug("loading serializers took %.3f seconds", timer.duration)

390 

391 

@functools.lru_cache(maxsize=None)
def _get_patterns() -> list[Pattern]:
    """Compile the whitespace-separated ``[core] allowed_deserialization_classes`` entries.

    The list is memoized by ``lru_cache`` after the first call.

    NOTE(review): these entries are used as globs (their ``.pattern`` text is fed to
    ``fnmatch``) but are compiled here as regular expressions, so a glob that is not
    also a valid regex (e.g. a bare ``*``) would presumably make ``re2.compile`` raise
    — confirm.
    """
    configured = conf.get("core", "allowed_deserialization_classes")
    return list(map(re2.compile, configured.split()))

395 

396 

@functools.lru_cache(maxsize=None)
def _get_regexp_patterns() -> list[Pattern]:
    """Compile the whitespace-separated ``[core] allowed_deserialization_classes_regexp`` entries.

    The list is memoized by ``lru_cache`` after the first call.
    """
    configured = conf.get("core", "allowed_deserialization_classes_regexp")
    compiled: list[Pattern] = []
    for expression in configured.split():
        compiled.append(re2.compile(expression))
    return compiled

400 

401 

# Populate the serializer / deserializer / stringifier registries at import time.
_register()